Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix races to c.Session() and c.SecureChannel() #654

Merged
merged 2 commits into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ func (c *Client) Connect(ctx context.Context) error {
return c.cfgerr
}

// todo(fs): the secure channel is 'nil' during a re-connect
// todo(fs): but we expect this method to be called once during startup
// todo(fs): so this is probably safe
if c.SecureChannel() != nil {
return errors.Errorf("already connected")
}
Expand Down Expand Up @@ -605,8 +608,8 @@ func (c *Client) CloseWithContext(ctx context.Context) error {
if c.mcancel != nil {
c.mcancel()
}
if c.SecureChannel() != nil {
c.SecureChannel().Close()
if sc := c.SecureChannel(); sc != nil {
sc.Close()
c.setSecureChannel(nil)
}

Expand Down Expand Up @@ -657,6 +660,8 @@ func (c *Client) setPublishTimeout(d time.Duration) {
}

// SecureChannel returns the active secure channel.
// During reconnect this value can change.
// Make sure to capture the value in a method before using it.
func (c *Client) SecureChannel() *uasc.SecureChannel {
return c.atomicSechan.Load().(*uasc.SecureChannel)
}
Expand All @@ -667,6 +672,8 @@ func (c *Client) setSecureChannel(sc *uasc.SecureChannel) {
}

// Session returns the active session.
// During reconnect this value can change.
// Make sure to capture the value in a method before using it.
func (c *Client) Session() *Session {
return c.atomicSession.Load().(*Session)
}
Expand All @@ -676,11 +683,6 @@ func (c *Client) setSession(s *Session) {
stats.Client().Add("Session", 1)
}

// sessionClosed returns true when there is no session.
func (c *Client) sessionClosed() bool {
return c.Session() == nil
}

// Session is a OPC/UA session as described in Part 4, 5.6.
type Session struct {
cfg *uasc.SessionConfig
Expand Down Expand Up @@ -727,7 +729,8 @@ func (c *Client) CreateSession(cfg *uasc.SessionConfig) (*Session, error) {

// Note: Starting with v0.5 this method is superseded by the non 'WithContext' method.
func (c *Client) CreateSessionWithContext(ctx context.Context, cfg *uasc.SessionConfig) (*Session, error) {
if c.SecureChannel() == nil {
sc := c.SecureChannel()
if sc == nil {
return nil, ua.StatusBadServerNotConnected
}

Expand All @@ -752,14 +755,14 @@ func (c *Client) CreateSessionWithContext(ctx context.Context, cfg *uasc.Session

var s *Session
// for the CreateSessionRequest the authToken is always nil.
// use c.SecureChannel().SendRequest() to enforce this.
err := c.SecureChannel().SendRequestWithContext(ctx, req, nil, func(v interface{}) error {
// use sc.SendRequest() to enforce this.
err := sc.SendRequestWithContext(ctx, req, nil, func(v interface{}) error {
var res *ua.CreateSessionResponse
if err := safeAssign(v, &res); err != nil {
return err
}

err := c.SecureChannel().VerifySessionSignature(res.ServerCertificate, nonce, res.ServerSignature.Signature)
err := sc.VerifySessionSignature(res.ServerCertificate, nonce, res.ServerSignature.Signature)
if err != nil {
log.Printf("error verifying session signature: %s", err)
return nil
Expand Down Expand Up @@ -820,11 +823,12 @@ func (c *Client) ActivateSession(s *Session) error {

// Note: Starting with v0.5 this method is superseded by the non 'WithContext' method.
func (c *Client) ActivateSessionWithContext(ctx context.Context, s *Session) error {
if c.SecureChannel() == nil {
sc := c.SecureChannel()
if sc == nil {
return ua.StatusBadServerNotConnected
}
stats.Client().Add("ActivateSession", 1)
sig, sigAlg, err := c.SecureChannel().NewSessionSignature(s.serverCertificate, s.serverNonce)
sig, sigAlg, err := sc.NewSessionSignature(s.serverCertificate, s.serverNonce)
if err != nil {
log.Printf("error creating session signature: %s", err)
return nil
Expand All @@ -835,7 +839,7 @@ func (c *Client) ActivateSessionWithContext(ctx context.Context, s *Session) err
// nothing to do

case *ua.UserNameIdentityToken:
pass, passAlg, err := c.SecureChannel().EncryptUserPassword(s.cfg.AuthPolicyURI, s.cfg.AuthPassword, s.serverCertificate, s.serverNonce)
pass, passAlg, err := sc.EncryptUserPassword(s.cfg.AuthPolicyURI, s.cfg.AuthPassword, s.serverCertificate, s.serverNonce)
if err != nil {
log.Printf("error encrypting user password: %s", err)
return err
Expand All @@ -844,7 +848,7 @@ func (c *Client) ActivateSessionWithContext(ctx context.Context, s *Session) err
tok.EncryptionAlgorithm = passAlg

case *ua.X509IdentityToken:
tokSig, tokSigAlg, err := c.SecureChannel().NewUserTokenSignature(s.cfg.AuthPolicyURI, s.serverCertificate, s.serverNonce)
tokSig, tokSigAlg, err := sc.NewUserTokenSignature(s.cfg.AuthPolicyURI, s.serverCertificate, s.serverNonce)
if err != nil {
log.Printf("error creating session signature: %s", err)
return err
Expand All @@ -868,7 +872,7 @@ func (c *Client) ActivateSessionWithContext(ctx context.Context, s *Session) err
UserIdentityToken: ua.NewExtensionObject(s.cfg.UserIdentityToken),
UserTokenSignature: s.cfg.UserTokenSignature,
}
return c.SecureChannel().SendRequestWithContext(ctx, req, s.resp.AuthenticationToken, func(v interface{}) error {
return sc.SendRequestWithContext(ctx, req, s.resp.AuthenticationToken, func(v interface{}) error {
var res *ua.ActivateSessionResponse
if err := safeAssign(v, &res); err != nil {
return err
Expand Down Expand Up @@ -965,14 +969,15 @@ func (c *Client) SendWithContext(ctx context.Context, req ua.Request, h func(int
// the response. If the client has an active session it injects the
// authentication token.
func (c *Client) sendWithTimeout(ctx context.Context, req ua.Request, timeout time.Duration, h func(interface{}) error) error {
if c.SecureChannel() == nil {
sc := c.SecureChannel()
if sc == nil {
return ua.StatusBadServerNotConnected
}
var authToken *ua.NodeID
if s := c.Session(); s != nil {
authToken = s.resp.AuthenticationToken
}
return c.SecureChannel().SendRequestWithTimeoutWithContext(ctx, req, authToken, timeout, h)
return sc.SendRequestWithTimeoutWithContext(ctx, req, authToken, timeout, h)
}

// Node returns a node object which accesses its attributes
Expand Down
11 changes: 9 additions & 2 deletions client_sub.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,21 @@ func (c *Client) sendRepublishRequests(ctx context.Context, sub *Subscription, a
req.RetransmitSequenceNumber,
)

if c.sessionClosed() {
s := c.Session()
if s == nil {
debug.Printf("Republishing subscription %d aborted", req.SubscriptionID)
return ua.StatusBadSessionClosed
}

sc := c.SecureChannel()
if sc == nil {
debug.Printf("Republishing subscription %d aborted", req.SubscriptionID)
return ua.StatusBadNotConnected
}

debug.Printf("RepublishRequest: req=%s", debug.ToJSON(req))
var res *ua.RepublishResponse
err := c.SecureChannel().SendRequestWithContext(ctx, req, c.Session().resp.AuthenticationToken, func(v interface{}) error {
err := sc.SendRequestWithContext(ctx, req, s.resp.AuthenticationToken, func(v interface{}) error {
return safeAssign(v, &res)
})
debug.Printf("RepublishResponse: res=%s err=%v", debug.ToJSON(res), err)
Expand Down