Skip to content

Commit

Permalink
Remove one call to get_ssh_client()
Browse files Browse the repository at this point in the history
Since in ssh_reachable() we already get a SSH client connection,
let's save it in the (unused so far) _ssh_client var.
Then reuse it, in _scp() command.
  • Loading branch information
Yaniv Kaul committed Apr 25, 2018
1 parent 306d9aa commit fe3baaa
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions lago/plugins/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def extract_paths(self, paths, ignore_nopath):
:exc:`~lago.plugins.vm.ExtractPathError`: on all other failures.
"""
if self.vm.alive() and self.vm.ssh_reachable(
tries=5, propagate_fail=False
tries=5, propagate_fail=False
):
self._extract_paths_scp(paths=paths, ignore_nopath=ignore_nopath)
else:
Expand Down Expand Up @@ -542,7 +542,6 @@ def ssh_script(self, path, show_output=True):
def alive(self):
return self.state() == 'running'

@check_alive
def ssh_reachable(self, tries=None, propagate_fail=True):
"""
Check if the VM is reachable with ssh
Expand All @@ -558,7 +557,7 @@ def ssh_reachable(self, tries=None, propagate_fail=True):
"""

try:
ssh.get_ssh_client(
self._ssh_client = ssh.get_ssh_client(
ip_addr=self.ip(),
host_name=self.name(),
ssh_tries=tries,
Expand Down Expand Up @@ -686,19 +685,23 @@ def _normalize_spec(cls, spec):

@contextlib.contextmanager
def _scp(self, propagate_fail=True):
client = ssh.get_ssh_client(
propagate_fail=propagate_fail,
ip_addr=self.ip(),
host_name=self.name(),
ssh_key=self.virt_env.prefix.paths.ssh_id_rsa(),
username=self._spec.get('ssh-user'),
password=self._spec.get('ssh-password'),
)
if self._ssh_client is not None:
client = self._ssh_client
else:
client = ssh.get_ssh_client(
propagate_fail=propagate_fail,
ip_addr=self.ip(),
host_name=self.name(),
ssh_key=self.virt_env.prefix.paths.ssh_id_rsa(),
username=self._spec.get('ssh-user'),
password=self._spec.get('ssh-password'),
)
scp = SCPClient(client.get_transport())
try:
yield scp
finally:
client.close()
self._ssh_client = None

def _detect_service_provider(self):
LOGGER.debug('Detecting service provider for %s', self.name())
Expand Down

0 comments on commit fe3baaa

Please sign in to comment.