Skip to content

Commit cb84e6a

Browse files
committed
Use the same SSH connection for main thread and completer thread
1 parent 7a059a6 commit cb84e6a

File tree

8 files changed

+115
-100
lines changed

8 files changed

+115
-100
lines changed

changelog.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ TBD
33

44
Features:
55
---------
6-
* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file.
7-
* Add an option `--list-ssh-config` to list ssh configurations.
8-
* Add an option `--ssh-config-path` to choose ssh configuration path.
6+
* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file (Thanks: [Nathan Huang]).
7+
* Add an option `--list-ssh-config` to list ssh configurations (Thanks: [Nathan Huang]).
8+
* Add an option `--ssh-config-path` to choose ssh configuration path (Thanks: [Nathan Huang]).
9+
* Reuse the same SSH connection in both main thread and completion thread (Thanks: [Georgy Frolov]).
910

1011

1112
1.21.1
@@ -757,3 +758,4 @@ Bug Fixes:
757758
[François Pietka]: https://github.com/fpietka
758759
[Frederic Aoustin]: https://github.com/fraoustin
759760
[Georgy Frolov]: https://github.com/pasenor
761+
[Nathan Huang]: https://github.com/hxueh

mycli/completion_refresher.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
5151
e = sqlexecute
5252
executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
5353
e.socket, e.charset, e.local_infile, e.ssl,
54-
e.ssh_user, e.ssh_host, e.ssh_port,
55-
e.ssh_password, e.ssh_key_filename)
54+
e.ssh_client)
5655

5756
# If callbacks is a single function then push it into a list.
5857
if callable(callbacks):

mycli/main.py

+18-46
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131
from prompt_toolkit.history import FileHistory
3232
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
3333

34+
from mycli.packages.ssh_client import create_ssh_client
3435
from .packages.special.main import NO_QUERY
3536
from .packages.prompt_utils import confirm, confirm_destructive_query
3637
from .packages.tabular_output import sql_format
37-
from .packages import special
38+
from .packages import special, ssh_client
3839
from .packages.special.favoritequeries import FavoriteQueries
3940
from .sqlcompleter import SQLCompleter
4041
from .clitoolbar import create_toolbar_tokens_func
@@ -63,11 +64,6 @@
6364
from urllib.parse import unquote
6465

6566

66-
try:
67-
import paramiko
68-
except ImportError:
69-
from mycli.packages.paramiko_stub import paramiko
70-
7167
# Query tuples are used for maintaining history
7268
Query = namedtuple('Query', ['query', 'successful', 'mutating'])
7369

@@ -198,6 +194,8 @@ def __init__(self, sqlexecute=None, prompt=None,
198194

199195
self.prompt_app = None
200196

197+
self.ssh_client = None
198+
201199
def register_special_commands(self):
202200
special.register_special_command(self.change_db, 'use',
203201
'\\u', 'Change to a new database.', aliases=('\\u',))
@@ -358,9 +356,7 @@ def merge_ssl_with_cnf(self, ssl, cnf):
358356
return merged
359357

360358
def connect(self, database='', user='', passwd='', host='', port='',
361-
socket='', charset='', local_infile='', ssl='',
362-
ssh_user='', ssh_host='', ssh_port='',
363-
ssh_password='', ssh_key_filename=''):
359+
socket='', charset='', local_infile='', ssl=None):
364360

365361
cnf = {'database': None,
366362
'user': None,
@@ -384,7 +380,7 @@ def connect(self, database='', user='', passwd='', host='', port='',
384380

385381
database = database or cnf['database']
386382
# Socket interface not supported for SSH connections
387-
if port or host or ssh_host or ssh_port:
383+
if port or host or self.ssh_client:
388384
socket = ''
389385
else:
390386
socket = socket or cnf['socket'] or guess_socket_location()
@@ -416,17 +412,15 @@ def _connect():
416412
try:
417413
self.sqlexecute = SQLExecute(
418414
database, user, passwd, host, port, socket, charset,
419-
local_infile, ssl, ssh_user, ssh_host, ssh_port,
420-
ssh_password, ssh_key_filename
415+
local_infile, ssl, ssh_client=self.ssh_client
421416
)
422417
except OperationalError as e:
423418
if ('Access denied for user' in e.args[1]):
424419
new_passwd = click.prompt('Password', hide_input=True,
425420
show_default=False, type=str, err=True)
426421
self.sqlexecute = SQLExecute(
427422
database, user, new_passwd, host, port, socket,
428-
charset, local_infile, ssl, ssh_user, ssh_host,
429-
ssh_port, ssh_password, ssh_key_filename
423+
charset, local_infile, ssl, ssh_client=self.ssh_client
430424
)
431425
else:
432426
raise e
@@ -1092,16 +1086,17 @@ def cli(database, user, host, port, socket, password, dbname,
10921086
else:
10931087
click.secho(alias)
10941088
sys.exit(0)
1089+
10951090
if list_ssh_config:
1096-
ssh_config = read_ssh_config(ssh_config_path)
1097-
for host in ssh_config.get_hostnames():
1091+
hosts = ssh_client.get_config_hosts(ssh_config_path)
1092+
for host, hostname in hosts.items():
10981093
if verbose:
1099-
host_config = ssh_config.lookup(host)
11001094
click.secho("{} : {}".format(
1101-
host, host_config.get('hostname')))
1095+
host, hostname))
11021096
else:
11031097
click.secho(host)
11041098
sys.exit(0)
1099+
11051100
# Choose which ever one has a valid value.
11061101
database = dbname or database
11071102

@@ -1153,7 +1148,7 @@ def cli(database, user, host, port, socket, password, dbname,
11531148
port = uri.port
11541149

11551150
if ssh_config_host:
1156-
ssh_config = read_ssh_config(
1151+
ssh_config = ssh_client.read_config_file(
11571152
ssh_config_path
11581153
).lookup(ssh_config_host)
11591154
ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
@@ -1164,7 +1159,10 @@ def cli(database, user, host, port, socket, password, dbname,
11641159
ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
11651160
'identityfile', [None])[0]
11661161

1167-
ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)
1162+
if ssh_host:
1163+
mycli.ssh_client = create_ssh_client(
1164+
ssh_host, ssh_port, ssh_user, ssh_password, ssh_key_filename
1165+
)
11681166

11691167
mycli.connect(
11701168
database=database,
@@ -1175,11 +1173,6 @@ def cli(database, user, host, port, socket, password, dbname,
11751173
socket=socket,
11761174
local_infile=local_infile,
11771175
ssl=ssl,
1178-
ssh_user=ssh_user,
1179-
ssh_host=ssh_host,
1180-
ssh_port=ssh_port,
1181-
ssh_password=ssh_password,
1182-
ssh_key_filename=ssh_key_filename
11831176
)
11841177

11851178
mycli.logger.debug('Launch Params: \n'
@@ -1298,26 +1291,5 @@ def edit_and_execute(event):
12981291
buff.open_in_editor(validate_and_handle=False)
12991292

13001293

1301-
def read_ssh_config(ssh_config_path):
1302-
ssh_config = paramiko.config.SSHConfig()
1303-
try:
1304-
with open(ssh_config_path) as f:
1305-
ssh_config.parse(f)
1306-
# Paramiko prior to version 2.7 raises Exception on parse errors.
1307-
# In 2.7 it has become paramiko.ssh_exception.SSHException,
1308-
# but let's catch everything for compatibility
1309-
except Exception as err:
1310-
click.secho(
1311-
f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ',
1312-
err=True, fg='red'
1313-
)
1314-
sys.exit(1)
1315-
except FileNotFoundError as e:
1316-
click.secho(str(e), err=True, fg='red')
1317-
sys.exit(1)
1318-
else:
1319-
return ssh_config
1320-
1321-
13221294
if __name__ == "__main__":
13231295
cli()

mycli/packages/ssh_client/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .client import get_config_hosts, create_ssh_client, SSHException, read_config_file

mycli/packages/ssh_client/client.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
A very thin wrapper around paramiko, mostly to keep all SSH-related functionality in one place
3+
"""
4+
from io import open
5+
6+
try:
7+
import paramiko
8+
except ImportError:
9+
from mycli.packages.paramiko_stub import paramiko
10+
11+
12+
class SSHException(Exception):
13+
pass
14+
15+
16+
def get_config_hosts(config_path):
17+
config = read_config_file(config_path)
18+
return {
19+
host: config.lookup(host).get("hostname") for host in config.get_hostnames()
20+
}
21+
22+
23+
def create_ssh_client(ssh_host, ssh_port, ssh_user, ssh_password=None, ssh_key_filename=None) -> paramiko.SSHClient:
24+
client = paramiko.SSHClient()
25+
client.load_system_host_keys()
26+
client.set_missing_host_key_policy(paramiko.WarningPolicy())
27+
client.connect(
28+
ssh_host, ssh_port, ssh_user, password=ssh_password, key_filename=ssh_key_filename
29+
)
30+
return client
31+
32+
33+
def read_config_file(config_path) -> paramiko.SSHConfig:
34+
ssh_config = paramiko.config.SSHConfig()
35+
try:
36+
with open(config_path) as f:
37+
ssh_config.parse(f)
38+
# Paramiko prior to version 2.7 raises Exception on parse errors.
39+
# In 2.7 it has become paramiko.ssh_exception.SSHException,
40+
# but let's catch everything for compatibility
41+
except Exception as err:
42+
raise SSHException(
43+
f"Could not parse SSH configuration file {config_path}:\n{err} ",
44+
)
45+
except FileNotFoundError as e:
46+
raise SSHException(str(e))
47+
return ssh_config

mycli/sqlexecute.py

+11-39
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66
from pymysql.converters import (convert_mysql_timestamp, convert_datetime,
77
convert_timedelta, convert_date, conversions,
88
decoders)
9-
try:
10-
import paramiko
11-
except ImportError:
12-
from mycli.packages.paramiko_stub import paramiko
9+
1310

1411
_logger = logging.getLogger(__name__)
1512

@@ -18,6 +15,7 @@
1815
FIELD_TYPE.NULL: type(None)
1916
})
2017

18+
2119
class SQLExecute(object):
2220

2321
databases_query = '''SHOW DATABASES'''
@@ -41,8 +39,8 @@ class SQLExecute(object):
4139
order by table_name,ordinal_position'''
4240

4341
def __init__(self, database, user, password, host, port, socket, charset,
44-
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
45-
ssh_key_filename):
42+
local_infile, ssl,
43+
ssh_client=None):
4644
self.dbname = database
4745
self.user = user
4846
self.password = password
@@ -54,17 +52,12 @@ def __init__(self, database, user, password, host, port, socket, charset,
5452
self.ssl = ssl
5553
self._server_type = None
5654
self.connection_id = None
57-
self.ssh_user = ssh_user
58-
self.ssh_host = ssh_host
59-
self.ssh_port = ssh_port
60-
self.ssh_password = ssh_password
61-
self.ssh_key_filename = ssh_key_filename
55+
self.ssh_client = ssh_client
56+
6257
self.connect()
6358

6459
def connect(self, database=None, user=None, password=None, host=None,
65-
port=None, socket=None, charset=None, local_infile=None,
66-
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
67-
ssh_password=None, ssh_key_filename=None):
60+
port=None, socket=None, charset=None, local_infile=None, ssl=None):
6861
db = (database or self.dbname)
6962
user = (user or self.user)
7063
password = (password or self.password)
@@ -74,11 +67,6 @@ def connect(self, database=None, user=None, password=None, host=None,
7467
charset = (charset or self.charset)
7568
local_infile = (local_infile or self.local_infile)
7669
ssl = (ssl or self.ssl)
77-
ssh_user = (ssh_user or self.ssh_user)
78-
ssh_host = (ssh_host or self.ssh_host)
79-
ssh_port = (ssh_port or self.ssh_port)
80-
ssh_password = (ssh_password or self.ssh_password)
81-
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
8270
_logger.debug(
8371
'Connection DB Params: \n'
8472
'\tdatabase: %r'
@@ -88,14 +76,8 @@ def connect(self, database=None, user=None, password=None, host=None,
8876
'\tsocket: %r'
8977
'\tcharset: %r'
9078
'\tlocal_infile: %r'
91-
'\tssl: %r'
92-
'\tssh_user: %r'
93-
'\tssh_host: %r'
94-
'\tssh_port: %r'
95-
'\tssh_password: %r'
96-
'\tssh_key_filename: %r',
79+
'\tssl: %r',
9780
db, user, host, port, socket, charset, local_infile, ssl,
98-
ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename
9981
)
10082
conv = conversions.copy()
10183
conv.update({
@@ -107,26 +89,16 @@ def connect(self, database=None, user=None, password=None, host=None,
10789

10890
defer_connect = False
10991

110-
if ssh_host:
111-
defer_connect = True
112-
11392
conn = pymysql.connect(
11493
database=db, user=user, password=password, host=host, port=port,
11594
unix_socket=socket, use_unicode=True, charset=charset,
11695
autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE,
11796
local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
118-
defer_connect=defer_connect
97+
defer_connect=self.ssh_client is not None
11998
)
12099

121-
if ssh_host:
122-
client = paramiko.SSHClient()
123-
client.load_system_host_keys()
124-
client.set_missing_host_key_policy(paramiko.WarningPolicy())
125-
client.connect(
126-
ssh_host, ssh_port, ssh_user, ssh_password,
127-
key_filename=ssh_key_filename
128-
)
129-
chan = client.get_transport().open_channel(
100+
if self.ssh_client:
101+
chan = self.ssh_client.get_transport().open_channel(
130102
'direct-tcpip',
131103
(host, port),
132104
('0.0.0.0', 0),

test/conftest.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import pytest
2+
3+
from mycli.packages.ssh_client import create_ssh_client
24
from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
35
db_connection, SSH_USER, SSH_HOST, SSH_PORT)
46
import mycli.sqlexecute
@@ -21,9 +23,13 @@ def cursor(connection):
2123

2224
@pytest.fixture
2325
def executor(connection):
26+
if SSH_HOST:
27+
ssh_client = create_ssh_client(SSH_HOST, SSH_PORT, SSH_USER)
28+
else:
29+
ssh_client = None
30+
2431
return mycli.sqlexecute.SQLExecute(
2532
database='_test_db', user=USER,
2633
host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
27-
local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
28-
ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
34+
local_infile=False, ssl=None, ssh_client=ssh_client
2935
)

0 commit comments

Comments
 (0)