diff --git a/client.go b/client.go index 3d3ef2ed..61227f10 100644 --- a/client.go +++ b/client.go @@ -268,44 +268,34 @@ func (c *Client) monitor(ctx context.Context) { dlog.Print("auto-reconnecting") - switch err { - case io.EOF: + switch { + case errors.Is(err, io.EOF): // the connection has been closed action = createSecureChannel - case syscall.ECONNREFUSED: + case errors.Is(err, syscall.ECONNREFUSED): // the connection has been refused by the server action = abortReconnect - default: - switch x := err.(type) { - case *uacp.Error: - switch ua.StatusCode(x.ErrorCode) { - case ua.StatusBadSecureChannelIDInvalid: - // the secure channel has been rejected by the server - action = createSecureChannel - - case ua.StatusBadSessionIDInvalid: - // the session has been rejected by the server - action = recreateSession - - case ua.StatusBadSubscriptionIDInvalid: - // the subscription has been rejected by the server - action = transferSubscriptions + case errors.Is(err, ua.StatusBadSecureChannelIDInvalid): + // the secure channel has been rejected by the server + action = createSecureChannel - case ua.StatusBadCertificateInvalid: - // todo(unknownet): recreate server certificate - fallthrough + case errors.Is(err, ua.StatusBadSessionIDInvalid): + // the session has been rejected by the server + action = recreateSession - default: - // unknown error has occured - action = createSecureChannel - } + case errors.Is(err, ua.StatusBadSubscriptionIDInvalid): + // the subscription has been rejected by the server + action = transferSubscriptions - default: - // unknown error has occured - action = createSecureChannel - } + case errors.Is(err, ua.StatusBadCertificateInvalid): + // todo(unknownet): recreate server certificate + fallthrough + + default: + // unknown error has occured + action = createSecureChannel } c.setState(Disconnected) diff --git a/client_sub.go b/client_sub.go index 2f311e3a..0da4fecd 100644 --- a/client_sub.go +++ b/client_sub.go @@ -135,18 +135,15 @@ func (c *Client) republishSubscription(ctx context.Context, id uint32, available debug.Printf("republishing subscription %d", sub.SubscriptionID) if err := c.sendRepublishRequests(ctx, sub, availableSeq); err != nil { - status, ok := err.(ua.StatusCode) - if !ok { - return err - } - - switch status { - case ua.StatusBadSessionIDInvalid: + switch { + case errors.Is(err, ua.StatusBadSessionIDInvalid): return nil - case ua.StatusBadSubscriptionIDInvalid: + case errors.Is(err, ua.StatusBadSubscriptionIDInvalid): // todo(fs): do we need to forget the subscription id in this case? debug.Printf("republish failed since subscription %d is invalid", sub.SubscriptionID) return errors.Errorf("republish failed since subscription %d is invalid", sub.SubscriptionID) + default: + return err } } return nil diff --git a/errors/errors.go b/errors/errors.go index cf6b0467..bfc84dac 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -1,22 +1,39 @@ package errors import ( + "errors" + pkg_errors "github.com/pkg/errors" ) // Prefix is the default error string prefix const Prefix = "opcua: " -// Errorf is a wrapper for `errors.Errorf` +// Errorf wraps github.com/pig/errors#Errorf` func Errorf(format string, a ...interface{}) error { return pkg_errors.Errorf(Prefix+format, a...) } -// New is a wrapper for `errors.New` +// New wraps github.com/pkg/errors#New func New(text string) error { return pkg_errors.New(Prefix + text) } +// Is wraps errors.Is +func Is(err error, target error) bool { + return errors.Is(err, target) +} + +// As wraps errors.As +func As(err error, target interface{}) bool { + return errors.As(err, target) +} + +// Unwrap wraps errors.Unwrap +func Unwrap(err error) error { + return errors.Unwrap(err) +} + // Equal returns true if the two errors have the same error message. // // todo(fs): the reason we need this function and cannot just use diff --git a/stats/stats.go b/stats/stats.go index 549fad1c..25c97786 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -4,6 +4,7 @@ package stats import ( + "errors" "expvar" "io" "reflect" @@ -45,22 +46,20 @@ func (s *Stats) RecordError(err error) { if err == nil { return } - switch err { - case io.EOF: + var code ua.StatusCode + switch { + case errors.Is(err, io.EOF): s.Error.Add("io.EOF", 1) - case ua.StatusOK: + case errors.Is(err, ua.StatusOK): s.Error.Add("ua.StatusOK", 1) - case ua.StatusBad: + case errors.Is(err, ua.StatusBad): s.Error.Add("ua.StatusBad", 1) - case ua.StatusUncertain: + case errors.Is(err, ua.StatusUncertain): s.Error.Add("ua.StatusUncertain", 1) + case errors.As(err, &code): + s.Error.Add("ua."+ua.StatusCodes[code].Name, 1) default: - switch x := err.(type) { - case ua.StatusCode: - s.Error.Add("ua."+ua.StatusCodes[x].Name, 1) - default: - s.Error.Add(reflect.TypeOf(err).String(), 1) - } + s.Error.Add(reflect.TypeOf(err).String(), 1) } } diff --git a/uacp/conn_test.go b/uacp/conn_test.go index decd211a..2269bbe2 100644 --- a/uacp/conn_test.go +++ b/uacp/conn_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/gopcua/opcua/errors" "github.com/pascaldekloe/goe/verify" ) @@ -56,7 +57,8 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() _, err := Dial(ctx, ep) - if !err.(*net.OpError).Timeout() { + var operr *net.OpError + if errors.As(err, &operr) && !operr.Timeout() { t.Error(err) } }) diff --git a/uacp/endpoint.go b/uacp/endpoint.go index b1b58e98..4df721d1 100644 --- a/uacp/endpoint.go +++ b/uacp/endpoint.go @@ -27,8 +27,8 @@ func ResolveEndpoint(endpoint string) (network string, addr *net.TCPAddr, err er network = "tcp" addr, err = net.ResolveTCPAddr(network, addrString) - switch err.(type) { - case *net.DNSError: + var dnserr *net.DNSError + if errors.As(err, &dnserr) { return "", nil, errors.Errorf("could not resolve address %s", addrString) } return diff --git a/uacp/uacp.go b/uacp/uacp.go index cf36c958..7aa37b91 100644 --- a/uacp/uacp.go +++ b/uacp/uacp.go @@ -169,6 +169,11 @@ func (e *Error) Error() string { return ua.StatusCode(e.ErrorCode).Error() } +// Unwrap returns the wrapped error code. +func (e *Error) Unwrap() error { + return ua.StatusCode(e.ErrorCode) +} + type Message struct { Data []byte } diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 960a5c39..4ee74d4f 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -344,7 +344,8 @@ func (s *SecureChannel) readChunk() (*MessageChunk, error) { return nil, io.EOF } // do not wrap this error since it hides conn error - if _, ok := err.(*uacp.Error); ok { + var uacperr *uacp.Error + if errors.As(err, &uacperr) { return nil, err } if err != nil { diff --git a/uatest/timeout_test.go b/uatest/timeout_test.go index 0d8507fa..0d29eba5 100644 --- a/uatest/timeout_test.go +++ b/uatest/timeout_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gopcua/opcua" + "github.com/gopcua/opcua/errors" ) const ( @@ -44,12 +45,14 @@ func connectAndValidate(t *testing.T, c *opcua.Client, ctx context.Context, d ti elapsed := time.Since(start) - if oe, ok := err.(*net.OpError); ok { - if !oe.Timeout() { - t.Fatalf("got %#v, wanted net.timeoutError", oe.Unwrap()) - } - } else { - t.Fatalf("got %T, wanted %T", err, net.OpError{}) + var oe *net.OpError + switch { + case errors.As(err, &oe) && !oe.Timeout(): + t.Fatalf("got %#v, wanted net.timeoutError", oe.Unwrap()) + case errors.As(err, &oe): + // ignore + default: + t.Fatalf("got %T, wanted %T", err, &net.OpError{}) } pct := 0.05