Skip to content

Commit 12a6af7

Browse files
authored
PYTHON-2981 Stop using MongoClient.address for hashing and equality (#795)
1 parent 24cc4c4 commit 12a6af7

File tree

6 files changed

+57
-14
lines changed

6 files changed

+57
-14
lines changed

doc/changelog.rst

+3
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ Breaking Changes in 4.0
177177
:exc:`~pymongo.errors.InvalidURI` exception
178178
when it encounters unescaped percent signs in username and password when
179179
parsing MongoDB URIs.
180+
- Comparing two :class:`~pymongo.mongo_client.MongoClient` instances now
181+
uses a set of immutable properties rather than
182+
:attr:`~pymongo.mongo_client.MongoClient.address` which can change.
180183
- Removed the `disable_md5` parameter for :class:`~gridfs.GridFSBucket` and
181184
:class:`~gridfs.GridFS`. See :ref:`removed-gridfs-checksum` for details.
182185
- PyMongoCrypt 1.2.0 or later is now required for client side field level

pymongo/mongo_client.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -748,9 +748,9 @@ def __init__(
748748
server_selector=options.server_selector,
749749
heartbeat_frequency=options.heartbeat_frequency,
750750
fqdn=fqdn,
751-
srv_service_name=srv_service_name,
752751
direct_connection=options.direct_connection,
753752
load_balanced=options.load_balanced,
753+
srv_service_name=srv_service_name,
754754
srv_max_hosts=srv_max_hosts
755755
)
756756

@@ -1337,14 +1337,14 @@ def _retryable_write(self, retryable, func, session):
13371337

13381338
def __eq__(self, other):
13391339
if isinstance(other, self.__class__):
1340-
return self.address == other.address
1340+
return self._topology == other._topology
13411341
return NotImplemented
13421342

13431343
def __ne__(self, other):
13441344
return not self == other
13451345

13461346
def __hash__(self):
1347-
return hash(self.address)
1347+
return hash(self._topology)
13481348

13491349
def _repr_helper(self):
13501350
def option_repr(option, value):

pymongo/monitor.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,6 @@ def __init__(self, topology, topology_settings):
299299
self._settings = topology_settings
300300
self._seedlist = self._settings._seeds
301301
self._fqdn = self._settings.fqdn
302-
self._srv_service_name = self._settings._srv_service_name
303302

304303
def _run(self):
305304
seedlist = self._get_seedlist()
@@ -319,7 +318,7 @@ def _get_seedlist(self):
319318
try:
320319
resolver = _SrvResolver(self._fqdn,
321320
self._settings.pool_options.connect_timeout,
322-
self._srv_service_name)
321+
self._settings.srv_service_name)
323322
seedlist, ttl = resolver.get_hosts_and_min_ttl()
324323
if len(seedlist) == 0:
325324
# As per the spec: this should be treated as a failure.

pymongo/settings.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def __init__(self,
3939
heartbeat_frequency=common.HEARTBEAT_FREQUENCY,
4040
server_selector=None,
4141
fqdn=None,
42-
srv_service_name=common.SRV_SERVICE_NAME,
4342
direct_connection=False,
4443
load_balanced=None,
44+
srv_service_name=common.SRV_SERVICE_NAME,
4545
srv_max_hosts=0):
4646
"""Represent MongoClient's configuration.
4747
@@ -62,11 +62,11 @@ def __init__(self,
6262
self._server_selection_timeout = server_selection_timeout
6363
self._server_selector = server_selector
6464
self._fqdn = fqdn
65-
self._srv_service_name = srv_service_name
6665
self._heartbeat_frequency = heartbeat_frequency
67-
self._srv_max_hosts = srv_max_hosts or 0
6866
self._direct = direct_connection
6967
self._load_balanced = load_balanced
68+
self._srv_service_name = srv_service_name
69+
self._srv_max_hosts = srv_max_hosts or 0
7070

7171
self._topology_id = ObjectId()
7272
# Store the allocation traceback to catch unclosed clients in the
@@ -131,6 +131,16 @@ def load_balanced(self):
131131
"""True if the client was configured to connect to a load balancer."""
132132
return self._load_balanced
133133

134+
@property
135+
def srv_service_name(self):
136+
"""The srvServiceName."""
137+
return self._srv_service_name
138+
139+
@property
140+
def srv_max_hosts(self):
141+
"""The srvMaxHosts."""
142+
return self._srv_max_hosts
143+
134144
def get_topology_type(self):
135145
if self.load_balanced:
136146
return TOPOLOGY_TYPE.LoadBalanced

pymongo/topology.py

+14
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,20 @@ def __repr__(self):
803803
msg = 'CLOSED '
804804
return '<%s %s%r>' % (self.__class__.__name__, msg, self._description)
805805

806+
def eq_props(self):
807+
"""The properties to use for MongoClient/Topology equality checks."""
808+
ts = self._settings
809+
return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn,
810+
ts.srv_service_name)
811+
812+
def __eq__(self, other):
813+
if isinstance(other, self.__class__):
814+
return self.eq_props() == other.eq_props()
815+
return NotImplemented
816+
817+
def __hash__(self):
818+
return hash(self.eq_props())
819+
806820

807821
class _ErrorContext(object):
808822
"""An error with context for SDAM error handling."""

test/test_client.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -676,15 +676,32 @@ def test_init_disconnected_with_auth(self):
676676
self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one)
677677

678678
def test_equality(self):
679-
c = connected(rs_or_single_client())
679+
seed = '%s:%s' % list(self.client._topology_settings.seeds)[0]
680+
c = rs_or_single_client(seed, connect=False)
681+
self.addCleanup(c.close)
680682
self.assertEqual(client_context.client, c)
681-
682683
# Explicitly test inequality
683684
self.assertFalse(client_context.client != c)
684685

686+
c = rs_or_single_client('invalid.com', connect=False)
687+
self.addCleanup(c.close)
688+
self.assertNotEqual(client_context.client, c)
689+
self.assertTrue(client_context.client != c)
690+
# Seeds differ:
691+
self.assertNotEqual(MongoClient('a', connect=False),
692+
MongoClient('b', connect=False))
693+
# Same seeds but out of order still compares equal:
694+
self.assertEqual(MongoClient(['a', 'b', 'c'], connect=False),
695+
MongoClient(['c', 'a', 'b'], connect=False))
696+
685697
def test_hashable(self):
686-
c = connected(rs_or_single_client())
698+
seed = '%s:%s' % list(self.client._topology_settings.seeds)[0]
699+
c = rs_or_single_client(seed, connect=False)
700+
self.addCleanup(c.close)
687701
self.assertIn(c, {client_context.client})
702+
c = rs_or_single_client('invalid.com', connect=False)
703+
self.addCleanup(c.close)
704+
self.assertNotIn(c, {client_context.client})
688705

689706
def test_host_w_port(self):
690707
with self.assertRaises(ValueError):
@@ -1635,19 +1652,19 @@ def test_service_name_from_kwargs(self):
16351652
client = MongoClient(
16361653
'mongodb+srv://user:password@test22.test.build.10gen.cc',
16371654
srvServiceName='customname', connect=False)
1638-
self.assertEqual(client._topology_settings._srv_service_name,
1655+
self.assertEqual(client._topology_settings.srv_service_name,
16391656
'customname')
16401657
client = MongoClient(
16411658
'mongodb+srv://user:password@test22.test.build.10gen.cc'
16421659
'/?srvServiceName=shouldbeoverriden',
16431660
srvServiceName='customname', connect=False)
1644-
self.assertEqual(client._topology_settings._srv_service_name,
1661+
self.assertEqual(client._topology_settings.srv_service_name,
16451662
'customname')
16461663
client = MongoClient(
16471664
'mongodb+srv://user:password@test22.test.build.10gen.cc'
16481665
'/?srvServiceName=customname',
16491666
connect=False)
1650-
self.assertEqual(client._topology_settings._srv_service_name,
1667+
self.assertEqual(client._topology_settings.srv_service_name,
16511668
'customname')
16521669

16531670
@unittest.skipUnless(

0 commit comments

Comments
 (0)