Skip to content

Commit

Permalink
refactor: Backwards compatible Encapsulate/Decapsulate/Join/NewCompon…
Browse files Browse the repository at this point in the history
…ent (#272)

* Backwards compatible Encapsulate/Decapsulate/Join

* remove EncapsulateC

* feat: Add Multiaddr.AppendComponent

* remove .Empty (#274)

prefer the simpler `len` check for Multiaddr and `== nil` for
`*Component`.

* export AsMultiaddrer interface

* rename AsMultiaddr to Multiaddr

Export Multiaddrer interface
  • Loading branch information
MarcoPolo authored Feb 24, 2025
1 parent 2ac523b commit 4d1f355
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 152 deletions.
33 changes: 21 additions & 12 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,36 @@ func stringToBytes(s string) ([]byte, error) {
return b.Bytes(), nil
}

func readComponent(b []byte) (int, Component, error) {
func readComponent(b []byte) (int, *Component, error) {
var offset int
code, n, err := ReadVarintCode(b)
if err != nil {
return 0, Component{}, err
return 0, nil, err
}
offset += n

p := ProtocolWithCode(code)
if p.Code == 0 {
return 0, Component{}, fmt.Errorf("no protocol with code %d", code)
return 0, nil, fmt.Errorf("no protocol with code %d", code)
}
pPtr := protocolPtrByCode[code]
if pPtr == nil {
return 0, Component{}, fmt.Errorf("no protocol with code %d", code)
return 0, nil, fmt.Errorf("no protocol with code %d", code)
}

if p.Size == 0 {
c, err := validateComponent(Component{
c := &Component{
bytes: string(b[:offset]),
valueStartIdx: offset,
protocol: pPtr,
})
}

err := validateComponent(c)
if err != nil {
return 0, nil, err
}

return offset, c, err
return offset, c, nil
}

var size int
Expand All @@ -100,7 +105,7 @@ func readComponent(b []byte) (int, Component, error) {
var n int
size, n, err = ReadVarintCode(b[offset:])
if err != nil {
return 0, Component{}, err
return 0, nil, err
}
offset += n
} else {
Expand All @@ -109,14 +114,18 @@ func readComponent(b []byte) (int, Component, error) {
}

if len(b[offset:]) < size || size <= 0 {
return 0, Component{}, fmt.Errorf("invalid value for size %d", len(b[offset:]))
return 0, nil, fmt.Errorf("invalid value for size %d", len(b[offset:]))
}

c, err := validateComponent(Component{
c := &Component{
bytes: string(b[:offset+size]),
protocol: pPtr,
valueStartIdx: offset,
})
}
err = validateComponent(c)
if err != nil {
return 0, nil, err
}

return offset + size, c, err
}
Expand All @@ -142,7 +151,7 @@ func readMultiaddr(b []byte) (int, Multiaddr, error) {
return bytesRead, nil, fmt.Errorf("unexpected component after path component")
}
sawPathComponent = c.protocol.Path
res = append(res, c)
res = append(res, *c)
}
return bytesRead, res, nil
}
86 changes: 42 additions & 44 deletions component.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,19 @@ type Component struct {
valueStartIdx int // Index of the first byte of the Component's value in the bytes array
}

func (c *Component) AsMultiaddr() Multiaddr {
if c.Empty() {
func (c *Component) Multiaddr() Multiaddr {
if c == nil {
return nil
}
return []Component{*c}
}

func (c *Component) Encapsulate(o Multiaddr) Multiaddr {
return c.AsMultiaddr().Encapsulate(o)
}

func (c *Component) Decapsulate(o Multiaddr) Multiaddr {
return c.AsMultiaddr().Decapsulate(o)
func (c *Component) Encapsulate(o Multiaddrer) Multiaddr {
return c.Multiaddr().Encapsulate(o)
}

func (c *Component) Empty() bool {
if c == nil {
return true
}
return len(c.bytes) == 0
func (c *Component) Decapsulate(o Multiaddrer) Multiaddr {
return c.Multiaddr().Decapsulate(o)
}

func (c *Component) Bytes() []byte {
Expand All @@ -63,7 +56,7 @@ func (c *Component) UnmarshalBinary(data []byte) error {
if err != nil {
return err
}
*c = comp
*c = *comp
return nil
}

Expand All @@ -87,7 +80,7 @@ func (c *Component) UnmarshalText(data []byte) error {
if err != nil {
return err
}
*c = comp
*c = *comp
return nil
}

Expand Down Expand Up @@ -180,9 +173,6 @@ func (c *Component) Value() string {
if c == nil {
return ""
}
if c.Empty() {
return ""
}
// This Component MUST have been checked by validateComponent when created
value, _ := c.valueAndErr()
return value
Expand Down Expand Up @@ -236,24 +226,24 @@ func (c *Component) writeTo(b *strings.Builder) {
}

// NewComponent constructs a new multiaddr component
func NewComponent(protocol, value string) (Component, error) {
func NewComponent(protocol, value string) (*Component, error) {
p := ProtocolWithName(protocol)
if p.Code == 0 {
return Component{}, fmt.Errorf("unsupported protocol: %s", protocol)
return nil, fmt.Errorf("unsupported protocol: %s", protocol)
}
if p.Transcoder != nil {
bts, err := p.Transcoder.StringToBytes(value)
if err != nil {
return Component{}, err
return nil, err
}
return newComponent(p, bts)
} else if value != "" {
return Component{}, fmt.Errorf("protocol %s doesn't take a value", p.Name)
return nil, fmt.Errorf("protocol %s doesn't take a value", p.Name)
}
return newComponent(p, nil)
}

func newComponent(protocol Protocol, bvalue []byte) (Component, error) {
func newComponent(protocol Protocol, bvalue []byte) (*Component, error) {
protocolPtr := protocolPtrByCode[protocol.Code]
if protocolPtr == nil {
protocolPtr = &protocol
Expand All @@ -274,71 +264,79 @@ func newComponent(protocol Protocol, bvalue []byte) (Component, error) {

// Shouldn't happen
if len(maddr) != offset+len(bvalue) {
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(maddr), offset+len(bvalue))
return nil, fmt.Errorf("component size mismatch: %d != %d", len(maddr), offset+len(bvalue))
}

return validateComponent(
Component{
bytes: string(maddr),
protocol: protocolPtr,
valueStartIdx: offset,
})
c := &Component{
bytes: string(maddr),
protocol: protocolPtr,
valueStartIdx: offset,
}

err := validateComponent(c)
if err != nil {
return nil, err
}
return c, nil
}

// validateComponent MUST be called after creating a non-zero Component.
// It ensures that we will be able to call all methods on Component without
// error.
func validateComponent(c Component) (Component, error) {
func validateComponent(c *Component) error {
if c == nil {
return errNilPtr
}
if c.protocol == nil {
return Component{}, fmt.Errorf("component is missing its protocol")
return fmt.Errorf("component is missing its protocol")
}
if c.valueStartIdx > len(c.bytes) {
return Component{}, fmt.Errorf("component valueStartIdx is greater than the length of the component's bytes")
return fmt.Errorf("component valueStartIdx is greater than the length of the component's bytes")
}

if len(c.protocol.VCode) == 0 {
return Component{}, fmt.Errorf("Component is missing its protocol's VCode field")
return fmt.Errorf("Component is missing its protocol's VCode field")
}
if len(c.bytes) < len(c.protocol.VCode) {
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.bytes), len(c.protocol.VCode))
return fmt.Errorf("component size mismatch: %d != %d", len(c.bytes), len(c.protocol.VCode))
}
if !bytes.Equal([]byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode) {
return Component{}, fmt.Errorf("component's VCode field is invalid: %v != %v", []byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode)
return fmt.Errorf("component's VCode field is invalid: %v != %v", []byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode)
}
if c.protocol.Size < 0 {
size, n, err := ReadVarintCode([]byte(c.bytes[len(c.protocol.VCode):]))
if err != nil {
return Component{}, err
return err
}
if size != len(c.bytes[c.valueStartIdx:]) {
return Component{}, fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
return fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
}

if len(c.protocol.VCode)+n+size != len(c.bytes) {
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+n+size, len(c.bytes))
return fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+n+size, len(c.bytes))
}
} else {
// Fixed size value
size := c.protocol.Size / 8
if size != len(c.bytes[c.valueStartIdx:]) {
return Component{}, fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
return fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
}

if len(c.protocol.VCode)+size != len(c.bytes) {
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+size, len(c.bytes))
return fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+size, len(c.bytes))
}
}

_, err := c.valueAndErr()
if err != nil {
return Component{}, err
return err

}
if c.protocol.Transcoder != nil {
err = c.protocol.Transcoder.ValidateBytes([]byte(c.bytes[c.valueStartIdx:]))
if err != nil {
return Component{}, err
return err
}
}
return c, nil
return nil
}
49 changes: 26 additions & 23 deletions multiaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,6 @@ func (m Multiaddr) copy() Multiaddr {
return out
}

func (m Multiaddr) Empty() bool {
if len(m) == 0 {
return true
}
for _, c := range m {
if !c.Empty() {
return false
}
}
return true
}

// NewMultiaddr parses and validates an input string, returning a *Multiaddr
func NewMultiaddr(s string) (a Multiaddr, err error) {
defer func() {
Expand Down Expand Up @@ -181,23 +169,38 @@ func (m Multiaddr) Protocols() []Protocol {
return out
}

// Encapsulate wraps a given Multiaddr, returning the resulting joined Multiaddr
func (m Multiaddr) Encapsulate(o Multiaddr) Multiaddr {
return Join(m, o)
type Multiaddrer interface {
// Multiaddr returns the Multiaddr representation
Multiaddr() Multiaddr
}

func (m Multiaddr) EncapsulateC(c *Component) Multiaddr {
if c.Empty() {
return m
func (m Multiaddr) Multiaddr() Multiaddr {
return m
}

// AppendComponent is the same as using `append(m, *c)`, but with a safety check
// for a nil Component.
func (m Multiaddr) AppendComponent(cs ...*Component) Multiaddr {
for _, c := range cs {
if c == nil {
continue
}
m = append(m, *c)
}
out := make([]Component, 0, len(m)+1)
out = append(out, m...)
out = append(out, *c)
return out
return m
}

// Encapsulate wraps a given Multiaddr, returning the resulting joined Multiaddr
func (m Multiaddr) Encapsulate(other Multiaddrer) Multiaddr {
return Join(m, other)
}

// Decapsulate unwraps Multiaddr up until the given Multiaddr is found.
func (m Multiaddr) Decapsulate(rightParts Multiaddr) Multiaddr {
func (m Multiaddr) Decapsulate(rightPartsAny Multiaddrer) Multiaddr {
if rightPartsAny == nil {
return m
}
rightParts := rightPartsAny.Multiaddr()
leftParts := m

lastIndex := -1
Expand Down
Loading

0 comments on commit 4d1f355

Please sign in to comment.