Skip to content

Commit

Permalink
Merge pull request #4 from weka/sep24_1
Browse files Browse the repository at this point in the history
Rework to use fabric
  • Loading branch information
vince-weka authored Sep 30, 2024
2 parents 6187ded + fe0bed2 commit bb921a6
Showing 1 changed file with 83 additions and 95 deletions.
178 changes: 83 additions & 95 deletions wekapyutils/wekassh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
#
import getpass
import os
from _socket import gaierror
from logging import getLogger

import paramiko
from scp import SCPClient
import fabric
#import paramiko
#from scp import SCPClient

from wekapyutils.sthreads import threaded, default_threader

Expand All @@ -21,47 +23,26 @@ def __str__(self):


class CommandOutput(object):
def __init__(self, status, stdout, stderr, exception):
def __init__(self, status, stdout, stderr, exception=None):
self.status = status
self.stdout = stdout
self.stderr = stderr
self.exception = exception

def __str__(self):
return f"status={self.status}, stdout={self.stdout}, stderr={self.stderr}, exception={self.exception}"

class RemoteServer(paramiko.SSHClient):
def __init__(self, hostname):
super().__init__()
self._sshconfig = paramiko.SSHConfig()
self._config_file = True
self.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.load_system_host_keys()

# handle missing config file
try:
fp = open(os.path.expanduser('~/.ssh/config'))
except IOError:
self.config_file = False
else:
try:
self._sshconfig.parse(fp)
except Exception as exc: # malformed config file?
log.critical(exc)
raise

class RemoteServer():
def __init__(self, hostname):
self.output = None
#self.connection = fabric.Connection(hostname)
self.connection = None
self._hostname = hostname
self.exc = None
self.user = ""
self.password = ""
self.connected = False
self.hostconfig = self._sshconfig.lookup(self._hostname)
if "user" in self.hostconfig:
self.user = self.hostconfig["user"]
else:
self.user = getpass.getuser()
self.password = "" # was None, but on linux it produces an error

if "identityfile" in self.hostconfig:
self.key_filename = self.hostconfig["identityfile"][0] # only take the first match, like OpenSSH
else:
self.key_filename = None

def ask_for_credentials(self):
print(f"Enter credentials for server {self._hostname}:")
Expand All @@ -75,82 +56,71 @@ def ask_for_credentials(self):

def connect(self):
success = False
self.kwargs = {"forward_agent": True}
while not success:
self.exc = None
self.kwargs = dict()

if self.user is not None:
self.kwargs["username"] = self.user
if self.password is not None and len(self.password) > 0:
self.kwargs["password"] = self.password

# don't give key_filename if they've provided a password
if self.key_filename is not None and "password" not in self.kwargs:
self.kwargs["key_filename"] = self.key_filename
else:
self.kwargs["key_filename"] = None
# self.kwargs["look_for_keys"] = True # actually the default...

try:
super().connect(self._hostname, **self.kwargs)
success = True
except paramiko.ssh_exception.AuthenticationException as exc:
log.critical(f"Authentication error opening ssh session to {self._hostname}: {exc}")
self.exc = AuthenticationException()
#self.connection = fabric.Connection(self._hostname, forward_agent=True)
self.connection = fabric.Connection(self._hostname, **self.kwargs)
result = self.connection.open()
self.connected = True
except gaierror as exc:
log.error(f"Error connecting to {self._hostname}: hostname not found")
self.connected = False
return
except Exception as exc:
log.critical(f"Exception opening ssh session to {self._hostname}: {exc}")
self.exc = exc
# ok, it's a gross hack, but we need to know if we're interactive or not
# ___interactive is assumed True, parallel() sets it to False
if not success:
if getattr(self, "___interactive", True):
self.ask_for_credentials()
else:
return # bail out if not interactive and error
self.connected = True
log.error(f"Error connecting to {self._hostname}: {exc}")
self.user = self.connection.user
self.ask_for_credentials()
connect_kwargs = {"password": self.password, "key_filename": []}
del self.connection
self.connection = fabric.Connection(self._hostname, user=self.user, connect_kwargs=connect_kwargs)
result = self.connection.open()
return self.connection

def close(self):
self.end_unending() # kills the fio --server process
super().close()
self.connection.close()
#super().close()

def get_transport(self):
return self.connection.transport

def scp(self, source, dest):
log.info(f"copying {source} to {self._hostname}")
with SCPClient(self.get_transport()) as scp:
scp.put(source, recursive=True, remote_path=dest)
self.connection.put(source, dest)

def run(self, cmd):
if self.connected == False:
"""
:param cmd:
:type cmd:
:return:returns a CommandOutput object with the results of the command
and also stores it in self.output
:rtype:
"""
if self.connection is None:
log.error(f'Cannot run command - not connected to host {self._hostname}')
self.output = CommandOutput(1,"","",None)
return self.output
exc = None
status = None
return
try:
stdin, stdout, stderr = self.exec_command(cmd, get_pty=True)
status = stdout.channel.recv_exit_status()
stdout.flush()
response = stdout.read().decode("utf-8")
error = stderr.read().decode("utf-8")
self.last_output = {'status': status, 'response': response, 'error': error, "exc": None}
if status != 0:
log.debug(f"run: Bad return code from {cmd[:100]}: {status}. Output is:")
log.debug(f"stdout is {response[:4000]}")
log.debug(f"stderr is {error[:4000]}")
else:
log.debug(f"run: 'status {status}, stdout {len(response)} bytes, stderr {len(error)} bytes")
result = self.connection.run(cmd, hide=True)
#self.output = CommandOutput(result.return_code, result.stdout, result.stderr, exc)
except gaierror as exc:
log.error(f"Error connecting to {self._hostname}: hostname not found")
self.output = CommandOutput(127, "hostname not found", "", exc)
except Exception as exc:
log.debug(f"run (Exception): '{cmd[:100]}', exception='{exc}'")
self.output = CommandOutput(1,"","",None)
return self.output
self.output = CommandOutput(status, response, error, exc)
result = exc.result
self.output = CommandOutput(result.return_code, result.stdout, result.stderr, exc)
else:
self.output = CommandOutput(result.return_code, result.stdout, result.stderr)
return self.output

def _linux_to_dict(self, separator):
output = dict()
if self.last_output['status'] != 0:
log.debug(f"last output = {self.last_output}")
if self.output['status'] != 0:
log.debug(f"last output = {self.output}")
raise Exception
lines = self.last_output['response'].split('\n')
lines = self.output['response'].split('\n')
for line in lines:
if len(line) != 0:
line_split = line.split(separator)
Expand All @@ -161,7 +131,7 @@ def _linux_to_dict(self, separator):
def _count_cpus(self):
""" count up the cpus; 0,1-4,7,etc """
num_cores = 0
cpulist = self.last_output['response'].strip(' \n').split(',')
cpulist = self.output.stdout.strip(' \n').split(',')
for item in cpulist:
if '-' in item:
parts = item.split('-')
Expand All @@ -186,8 +156,8 @@ def gather_facts(self, weka):

if weka:
self.run('mount | grep wekafs')
log.debug(f"{self.last_output}")
if len(self.last_output['response']) == 0:
log.debug(f"{self.output}")
if len(self.output['response']) == 0:
log.debug(f"{self._hostname} does not have a weka filesystem mounted.")
self.weka_mounted = False
else:
Expand All @@ -197,22 +167,23 @@ def file_exists(self, path):
""" see if a file exists on another server """
log.debug(f"checking for presence of file {path} on server {self._hostname}")
self.run(f"if [ -f '{path}' ]; then echo 'True'; else echo 'False'; fi")
strippedstr = self.last_output['response'].strip(' \n')
strippedstr = self.output['response'].strip(' \n')
log.debug(f"server responded with {strippedstr}")
if strippedstr == "True":
return True
else:
return False

def last_response(self):
return self.last_output['response'].strip(' \n')
return self.output

def __str__(self):
return self._hostname

def run_unending(self, command):
""" run a command that never ends - needs to be terminated by ^c or something """
transport = self.get_transport()
#transport = self.get_transport()
transport = self.connection.get_transport()
self.unending_session = transport.open_session()
self.unending_session.setblocking(0) # Set to non-blocking mode
self.unending_session.get_pty()
Expand Down Expand Up @@ -251,3 +222,20 @@ def pdsh(servers, command):
def pscp(servers, source, dest):
log.debug(f"setting up parallel copy to {servers}")
parallel(servers, RemoteServer.scp, source, dest)

if __name__ == '__main__':
test1 = RemoteServer("wms")
result = test1.connect()
result2 = test1.run("date")
print(result2)
print(result2.stdout)
test1.scp("wekassh2.py", "/tmp/wekassh2.py")

servers = [RemoteServer("wms"), RemoteServer("buckaroo"), RemoteServer("whorfin")]
parallel(servers, RemoteServer.connect)
parallel(servers, RemoteServer.run, "hostname")
default_threader.run()
print("done")
for i in servers:
print(i.last_response())
pass

0 comments on commit bb921a6

Please sign in to comment.