Skip to content

Commit

Permalink
refactor: keep same api as v0.14.0 for SplitFirst/SplitLast (#271)
Browse files Browse the repository at this point in the history
* refactor: keep same api for SplitFirst/SplitLast

Not worth breaking this function, and it saves us some pain with
updating downstream deps

* add defensive copy
  • Loading branch information
MarcoPolo authored Feb 21, 2025
1 parent 4abf520 commit 2ac523b
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 42 deletions.
89 changes: 70 additions & 19 deletions component.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,39 @@ type Component struct {
valueStartIdx int // Index of the first byte of the Component's value in the bytes array
}

func (c Component) AsMultiaddr() Multiaddr {
func (c *Component) AsMultiaddr() Multiaddr {
if c.Empty() {
return nil
}
return []Component{c}
return []Component{*c}
}

func (c Component) Encapsulate(o Multiaddr) Multiaddr {
func (c *Component) Encapsulate(o Multiaddr) Multiaddr {
return c.AsMultiaddr().Encapsulate(o)
}

func (c Component) Decapsulate(o Multiaddr) Multiaddr {
func (c *Component) Decapsulate(o Multiaddr) Multiaddr {
return c.AsMultiaddr().Decapsulate(o)
}

func (c Component) Empty() bool {
func (c *Component) Empty() bool {
if c == nil {
return true
}
return len(c.bytes) == 0
}

func (c Component) Bytes() []byte {
func (c *Component) Bytes() []byte {
if c == nil {
return nil
}
return []byte(c.bytes)
}

func (c Component) MarshalBinary() ([]byte, error) {
func (c *Component) MarshalBinary() ([]byte, error) {
if c == nil {
return nil, errNilPtr
}
return c.Bytes(), nil
}

Expand All @@ -58,7 +67,10 @@ func (c *Component) UnmarshalBinary(data []byte) error {
return nil
}

func (c Component) MarshalText() ([]byte, error) {
func (c *Component) MarshalText() ([]byte, error) {
if c == nil {
return nil, errNilPtr
}
return []byte(c.String()), nil
}

Expand All @@ -79,7 +91,10 @@ func (c *Component) UnmarshalText(data []byte) error {
return nil
}

func (c Component) MarshalJSON() ([]byte, error) {
func (c *Component) MarshalJSON() ([]byte, error) {
if c == nil {
return nil, errNilPtr
}
txt, err := c.MarshalText()
if err != nil {
return nil, err
Expand All @@ -101,22 +116,40 @@ func (c *Component) UnmarshalJSON(data []byte) error {
return c.UnmarshalText([]byte(v))
}

func (c Component) Equal(o Component) bool {
func (c *Component) Equal(o *Component) bool {
if c == nil || o == nil {
return c == o
}
return c.bytes == o.bytes
}

func (c Component) Compare(o Component) int {
func (c *Component) Compare(o *Component) int {
if c == nil && o == nil {
return 0
}
if c == nil {
return -1
}
if o == nil {
return 1
}
return strings.Compare(c.bytes, o.bytes)
}

func (c Component) Protocols() []Protocol {
func (c *Component) Protocols() []Protocol {
if c == nil {
return nil
}
if c.protocol == nil {
return nil
}
return []Protocol{*c.protocol}
}

func (c Component) ValueForProtocol(code int) (string, error) {
func (c *Component) ValueForProtocol(code int) (string, error) {
if c == nil {
return "", fmt.Errorf("component is nil")
}
if c.protocol == nil {
return "", fmt.Errorf("component has nil protocol")
}
Expand All @@ -126,18 +159,27 @@ func (c Component) ValueForProtocol(code int) (string, error) {
return c.Value(), nil
}

func (c Component) Protocol() Protocol {
func (c *Component) Protocol() Protocol {
if c == nil {
return Protocol{}
}
if c.protocol == nil {
return Protocol{}
}
return *c.protocol
}

func (c Component) RawValue() []byte {
func (c *Component) RawValue() []byte {
if c == nil {
return nil
}
return []byte(c.bytes[c.valueStartIdx:])
}

func (c Component) Value() string {
func (c *Component) Value() string {
if c == nil {
return ""
}
if c.Empty() {
return ""
}
Expand All @@ -146,7 +188,10 @@ func (c Component) Value() string {
return value
}

func (c Component) valueAndErr() (string, error) {
func (c *Component) valueAndErr() (string, error) {
if c == nil {
return "", errNilPtr
}
if c.protocol == nil {
return "", fmt.Errorf("component has nil protocol")
}
Expand All @@ -160,15 +205,21 @@ func (c Component) valueAndErr() (string, error) {
return value, nil
}

func (c Component) String() string {
func (c *Component) String() string {
if c == nil {
return "<nil component>"
}
var b strings.Builder
c.writeTo(&b)
return b.String()
}

// writeTo is an efficient, private function for string-formatting a multiaddr.
// Trust me, we tend to allocate a lot when doing this.
func (c Component) writeTo(b *strings.Builder) {
func (c *Component) writeTo(b *strings.Builder) {
if c == nil {
return
}
if c.protocol == nil {
return
}
Expand Down
19 changes: 14 additions & 5 deletions multiaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ var errNilPtr = errors.New("nil ptr")
// Multiaddr is the data structure representing a Multiaddr
type Multiaddr []Component

func (m Multiaddr) copy() Multiaddr {
if m == nil {
return nil
}
out := make(Multiaddr, len(m))
copy(out, m)
return out
}

func (m Multiaddr) Empty() bool {
if len(m) == 0 {
return true
Expand Down Expand Up @@ -71,7 +80,7 @@ func (m Multiaddr) Equal(m2 Multiaddr) bool {
return false
}
for i, c := range m {
if !c.Equal(m2[i]) {
if !c.Equal(&m2[i]) {
return false
}
}
Expand All @@ -80,7 +89,7 @@ func (m Multiaddr) Equal(m2 Multiaddr) bool {

func (m Multiaddr) Compare(o Multiaddr) int {
for i := 0; i < len(m) && i < len(o); i++ {
if cmp := m[i].Compare(o[i]); cmp != 0 {
if cmp := m[i].Compare(&o[i]); cmp != 0 {
return cmp
}
}
Expand Down Expand Up @@ -177,13 +186,13 @@ func (m Multiaddr) Encapsulate(o Multiaddr) Multiaddr {
return Join(m, o)
}

func (m Multiaddr) EncapsulateC(c Component) Multiaddr {
func (m Multiaddr) EncapsulateC(c *Component) Multiaddr {
if c.Empty() {
return m
}
out := make([]Component, 0, len(m)+1)
out = append(out, m...)
out = append(out, c)
out = append(out, *c)
return out
}

Expand All @@ -200,7 +209,7 @@ func (m Multiaddr) Decapsulate(rightParts Multiaddr) Multiaddr {
break
}

foundMatch = rightC.Equal(leftParts[i+j])
foundMatch = rightC.Equal(&leftParts[i+j])
if !foundMatch {
break
}
Expand Down
50 changes: 45 additions & 5 deletions multiaddr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ func TestReturnsNilOnEmpty(t *testing.T) {
a, _ = SplitLast(a)
require.Nil(t, a)

a, c := SplitLast(nil)
require.Zero(t, len(a.Protocols()))
require.Nil(t, a)
require.Nil(t, c)
require.True(t, c.Empty())

// Test that empty multiaddr from various operations returns nil
a = StringCast("/ip4/1.2.3.4/tcp/1234")
_, a = SplitFirst(a)
Expand All @@ -36,6 +42,11 @@ func TestReturnsNilOnEmpty(t *testing.T) {
_, a = SplitFirst(a)
require.Nil(t, a)

c, a = SplitFirst(nil)
require.Nil(t, a)
require.Nil(t, c)
require.True(t, c.Empty())

a = StringCast("/ip4/1.2.3.4/tcp/1234")
a = a.Decapsulate(a)
require.Nil(t, a)
Expand Down Expand Up @@ -400,7 +411,7 @@ func TestBytesSplitAndJoin(t *testing.T) {

for i, a := range split {
if a.String() != res[i] {
t.Errorf("split component failed: %s != %s", a, res[i])
t.Errorf("split component failed: %s != %s", &a, res[i])
}
}

Expand All @@ -411,7 +422,7 @@ func TestBytesSplitAndJoin(t *testing.T) {

for i, a := range split {
if a.String() != res[i] {
t.Errorf("split component failed: %s != %s", a, res[i])
t.Errorf("split component failed: %s != %s", &a, res[i])
}
}
}
Expand Down Expand Up @@ -863,7 +874,7 @@ func TestComponentBinaryMarshaler(t *testing.T) {
if err = comp2.UnmarshalBinary(b); err != nil {
t.Fatal(err)
}
if !comp.Equal(comp2) {
if !comp.Equal(&comp2) {
t.Error("expected equal components in circular marshaling test")
}
}
Expand All @@ -882,7 +893,7 @@ func TestComponentTextMarshaler(t *testing.T) {
if err = comp2.UnmarshalText(b); err != nil {
t.Fatal(err)
}
if !comp.Equal(comp2) {
if !comp.Equal(&comp2) {
t.Error("expected equal components in circular marshaling test")
}
}
Expand All @@ -901,7 +912,7 @@ func TestComponentJSONMarshaler(t *testing.T) {
if err = comp2.UnmarshalJSON(b); err != nil {
t.Fatal(err)
}
if !comp.Equal(comp2) {
if !comp.Equal(&comp2) {
t.Error("expected equal components in circular marshaling test")
}
}
Expand All @@ -914,6 +925,9 @@ func TestUseNil(t *testing.T) {
_ = f()

var foo Multiaddr = nil
_, right := SplitFirst(foo)
right.Protocols()
foo.Protocols()
foo.Bytes()
foo.Compare(nil)
foo.Decapsulate(nil)
Expand All @@ -930,6 +944,32 @@ func TestUseNil(t *testing.T) {
_, _ = foo.ValueForProtocol(0)
}

func TestUseNilComponent(t *testing.T) {
var foo *Component
foo.AsMultiaddr()
foo.Encapsulate(nil)
foo.Decapsulate(nil)
foo.Empty()
foo.Bytes()
foo.MarshalBinary()
foo.MarshalJSON()
foo.MarshalText()
foo.UnmarshalBinary(nil)
foo.UnmarshalJSON(nil)
foo.UnmarshalText(nil)
foo.Equal(nil)
foo.Compare(nil)
foo.Protocols()
foo.ValueForProtocol(0)
foo.Protocol()
foo.RawValue()
foo.Value()
_ = foo.String()

var m Multiaddr = nil
m.EncapsulateC(foo)
}

func TestFilterAddrs(t *testing.T) {
bad := []Multiaddr{
newMultiaddr(t, "/ip6/fe80::1/tcp/1234"),
Expand Down
Loading

0 comments on commit 2ac523b

Please sign in to comment.