diff --git a/internal/core/services/claims.go b/internal/core/services/claims.go index e396108f4..df7c9c283 100644 --- a/internal/core/services/claims.go +++ b/internal/core/services/claims.go @@ -653,7 +653,7 @@ func (c *claim) getAgentCredential(ctx context.Context, basicMessage *ports.Agen return nil, fmt.Errorf("invalid credential fetch request body: %w", err) } - claimID, err := urn.Parse(urn.URN(fetchRequestBody.ID)) + claimID, err := urn.UUIDFromURNString(fetchRequestBody.ID) if err != nil { claimID, err = uuid.Parse(fetchRequestBody.ID) if err != nil { diff --git a/internal/urn/urn.go b/internal/urn/urn.go index 0f66b1b9a..b6fd782fd 100644 --- a/internal/urn/urn.go +++ b/internal/urn/urn.go @@ -14,13 +14,37 @@ func FromUUID(uuid uuid.UUID) URN { return URN("urn:uuid:" + uuid.String()) } -// Parse extracts a UUID from a URN. -func Parse(u URN) (uuid.UUID, error) { +// UUID returns the UUID from a URN. It can throw an error to prevent bad constructor calls or urns without uuids +func (u URN) UUID() (uuid.UUID, error) { + if err := u.valid(); err != nil { + return uuid.Nil, err + } + return uuid.Parse(string(u[9:])) +} + +func (u URN) valid() error { if len(u) < len("urn:uuid:") { - return uuid.UUID{}, errors.New("invalid uuid URN length") + return errors.New("invalid uuid URN length") } if u[:9] != "urn:uuid:" { - return uuid.UUID{}, errors.New("invalid uuid URN prefix") + return errors.New("invalid uuid URN prefix") } - return uuid.Parse(string(u[9:])) + return nil +} + +// Parse creates a URN from a string. +func Parse(u string) (URN, error) { + if err := URN(u).valid(); err != nil { + return "", err + } + return URN(u), nil +} + +// UUIDFromURNString returns the UUID from a URN string. +func UUIDFromURNString(s string) (uuid.UUID, error) { + urn, err := Parse(s) + if err != nil { + return uuid.Nil, err + } + return urn.UUID() } diff --git a/internal/urn/urn_test.go b/internal/urn/urn_test.go index bac33d58a..db662452d 100644 --- a/internal/urn/urn_test.go +++ b/internal/urn/urn_test.go @@ -14,7 +14,7 @@ func TestFromUUID(t *testing.T) { assert.Equal(t, "urn:uuid:"+id.String(), string(urn)) } -func TestParse(t *testing.T) { +func TestUUIDFromURNString(t *testing.T) { for _, ts := range []struct { urn string err error @@ -25,7 +25,7 @@ func TestParse(t *testing.T) { {"urn:uuid:123e4567-e89b-12d3-a456-426614174000", nil}, } { t.Run(ts.urn, func(t *testing.T) { - u, err := Parse(URN(ts.urn)) + u, err := UUIDFromURNString(ts.urn) if err == nil { assert.Equal(t, ts.urn[9:], u.String()) } else {