diff --git a/uasc/secure_channel_test.go b/uasc/secure_channel_test.go index 91fe451b..3918a34f 100644 --- a/uasc/secure_channel_test.go +++ b/uasc/secure_channel_test.go @@ -1,13 +1,11 @@ package uasc import ( - "bytes" "crypto/rsa" "crypto/x509" "encoding/pem" "fmt" "math" - "strings" "testing" "time" @@ -146,9 +144,7 @@ func TestNewRequestMessage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m, err := tt.sechan.activeInstance.newRequestMessage(tt.req, tt.sechan.nextRequestID(), tt.authToken, tt.timeout) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) require.Equal(t, tt.m, m) }) } @@ -159,21 +155,15 @@ func TestSignAndEncryptVerifyAndDecrypt(t *testing.T) { t.Helper() certPEM, keyPEM, err := uatest.GenerateCert("localhost", bits, 24*time.Hour) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) block, _ := pem.Decode(keyPEM) pk, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) certblock, _ := pem.Decode(certPEM) remoteX509Cert, err := x509.ParseCertificate(certblock.Bytes) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) remoteKey := remoteX509Cert.PublicKey.(*rsa.PublicKey) alg, _ := uapolicy.Asymmetric(uri, pk, remoteKey) @@ -293,23 +283,17 @@ func TestSignAndEncryptVerifyAndDecrypt(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cipher, err := tt.c.signAndEncrypt(tt.m, tt.b) - if err != nil { - t.Fatalf("error: message encrypt: %v", err) - } + require.NoError(t, err, "error: message encrypt") m := new(MessageChunk) - if _, err := m.Decode(cipher); err != nil { - t.Fatalf("error: message decode: %v", err) - } + _, err = m.Decode(cipher) + require.NoError(t, err, "error: message decode") + plain, err := tt.c.verifyAndDecrypt(m, cipher) - if err != nil { - t.Fatalf("error: message decrypt: %v", err) - } + require.NoError(t, err, "error: message decrypt") headerLength := 12 + m.AsymmetricSecurityHeader.Len() - if got, want := plain, tt.b[headerLength:]; !bytes.Equal(got, want) { - t.Fatalf("got bytes %v want %v", got, want) - } + require.Equal(t, tt.b[headerLength:], plain, "header not equal") }) } } @@ -317,39 +301,29 @@ func TestSignAndEncryptVerifyAndDecrypt(t *testing.T) { func TestNewSecureChannel(t *testing.T) { t.Run("no connection", func(t *testing.T) { _, err := NewSecureChannel("", nil, nil, nil) - errorContains(t, err, "no connection") + require.ErrorContains(t, err, "no connection") }) t.Run("no error channel", func(t *testing.T) { _, err := NewSecureChannel("", &uacp.Conn{}, nil, nil) - errorContains(t, err, "no secure channel config") + require.ErrorContains(t, err, "no secure channel config") }) t.Run("no config", func(t *testing.T) { _, err := NewSecureChannel("", &uacp.Conn{}, nil, make(chan error)) - errorContains(t, err, "no secure channel config") + require.ErrorContains(t, err, "no secure channel config") }) t.Run("uri none, mode not none", func(t *testing.T) { cfg := &Config{SecurityPolicyURI: ua.SecurityPolicyURINone, SecurityMode: ua.MessageSecurityModeSign} _, err := NewSecureChannel("", &uacp.Conn{}, cfg, make(chan error)) - errorContains(t, err, "invalid channel config: Security policy 'http://opcfoundation.org/UA/SecurityPolicy#None' cannot be used with 'MessageSecurityModeSign'") + require.ErrorContains(t, err, "invalid channel config: Security policy 'http://opcfoundation.org/UA/SecurityPolicy#None' cannot be used with 'MessageSecurityModeSign'") }) t.Run("uri not none, mode none", func(t *testing.T) { cfg := &Config{SecurityPolicyURI: ua.SecurityPolicyURIBasic256, SecurityMode: ua.MessageSecurityModeNone} _, err := NewSecureChannel("", &uacp.Conn{}, cfg, make(chan error)) - errorContains(t, err, "Security policy 'http://opcfoundation.org/UA/SecurityPolicy#Basic256' cannot be used with 'MessageSecurityModeNone'") + require.ErrorContains(t, err, "Security policy 'http://opcfoundation.org/UA/SecurityPolicy#Basic256' cannot be used with 'MessageSecurityModeNone'") }) t.Run("uri not none, local key missing", func(t *testing.T) { cfg := &Config{SecurityPolicyURI: ua.SecurityPolicyURIBasic256, SecurityMode: ua.MessageSecurityModeSign} _, err := NewSecureChannel("", &uacp.Conn{}, cfg, make(chan error)) - errorContains(t, err, "invalid channel config: Security policy 'http://opcfoundation.org/UA/SecurityPolicy#Basic256' requires a private key") + require.ErrorContains(t, err, "invalid channel config: Security policy 'http://opcfoundation.org/UA/SecurityPolicy#Basic256' requires a private key") }) } - -func errorContains(t *testing.T, err error, msg string) { - t.Helper() - if err == nil { - t.Fatal("expected an error but got nil") - } - if !strings.Contains(err.Error(), msg) { - t.Fatalf("error '%s' does not contain '%s'", err, msg) - } -}