"""Handler for operations, such as "download", on ssh:// URLs"""
# allow for |-type UnionType declarations
from __future__ import annotations
import logging
import random
import sys
import time
from functools import partial
from math import floor
from pathlib import (
Path,
PurePosixPath,
)
from queue import (
Full,
Queue,
)
from typing import (
Any,
Dict,
IO,
)
from urllib.parse import (
urlparse,
ParseResult,
)
from more_itertools import consume
from datalad_next.consts import COPY_BUFSIZE
from datalad_next.config import ConfigManager
from datalad_next.runners import CommandError
from datalad_next.shell import (
FixedLengthResponseGeneratorPosix,
ShellCommandExecutor,
shell,
)
from .base import UrlOperations
from .exceptions import (
UrlOperationsRemoteError,
UrlOperationsResourceUnknown,
)
lgr = logging.getLogger('datalad.ext.next.ssh_url_operations')
__all__ = ['SshUrlOperations']
class SshUrlOperations(UrlOperations):
"""Handler for operations on ``ssh://`` URLs
For downloading files, only servers that support execution of the commands
'ls -dln', 'awk', and 'cat' are supported. This includes a wide
range of operating systems, including devices that provide these commands
via the 'busybox' software.
.. note::
Any instance of ``SshUrlOperations`` must be deleted before ending the
program, otherwise python might not exit. The reason is, that
``SshUrlOperations`` retains and reuses SSH connections for subsequent
command execution. Each connection has two threads associated with it.
Those threads are only terminated when the connection is closed. The
destructor of ``SshUrlOperations`` closes all connections and terminates
all associated threads.
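
    A minimal usage sketch (host and path are hypothetical, assuming
    non-interactive, e.g. key-based, SSH authentication)::

        ops = SshUrlOperations()
        try:
            props = ops.stat('ssh://example.com/tmp/data.bin')
        finally:
            # drop the instance to close cached connections and stop
            # their worker threads
            del ops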
"""
def __init__(self, *, cfg: ConfigManager | None = None):
super().__init__(cfg=cfg)
self.ssh_shells: dict[tuple[str, ...], tuple[ShellCommandExecutor, Any]] = dict()
def __del__(self):
for ssh_executor, context in self.ssh_shells.values():
ssh_executor.close()
context.__exit__(None, None, None)
    @staticmethod
    def _check_return_code(return_code: int | None, url: str, msg: str = ''):
        if return_code == 244:
            # this is the special code for a file-not-found
            raise UrlOperationsResourceUnknown(url, message=msg)
        elif return_code != 0:
            raise UrlOperationsRemoteError(
                url,
                # parenthesize the conditional, otherwise the whole
                # `+`-expression becomes its subject
                message=f'ssh command returned {return_code}'
                        + (f': {msg}' if msg else '')
            )

    def format_cmd(self, payload_cmd: str, url: str) -> str:
        """Expand ``{fpath}``-style placeholders in a command template

        Delegates to ``_SshCommandBuilder``, which derives the substitution
        values from ``url``.
        """
        return _SshCommandBuilder(url, self.cfg).substitute(payload_cmd)
def ssh_shell_for(self,
url: str) -> ShellCommandExecutor:
"""Get a ShellCommandExecutor for the url (cached or newly created)"""
open_args = ssh_url2openargs(url, self.cfg)[0]
key = tuple(open_args)
if key not in self.ssh_shells:
context = shell(['ssh'] + open_args)
try:
ssh_executor = context.__enter__()
except CommandError as e:
context.__exit__(None, None, None)
raise UrlOperationsRemoteError(url) from e
self.ssh_shells[key] = (ssh_executor, context)
return self.ssh_shells[key][0]
def close_shell_for(self, url: str):
"""Close the ShellCommandExecutor for the url and remove it"""
open_args = ssh_url2openargs(url, self.cfg)[0]
key = tuple(open_args)
if key in self.ssh_shells:
ssh_executor, context = self.ssh_shells.pop(key)
ssh_executor.close()
context.__exit__(None, None, None)
def stat(self,
url: str,
*,
credential: str | None = None,
timeout: float | None = None) -> Dict:
"""Gather information on a URL target, without downloading it
See :meth:`datalad_next.url_operations.UrlOperations.stat`
for parameter documentation and exception behavior.
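
        A hypothetical call (host and path are illustrative only)::

            props = SshUrlOperations().stat('ssh://example.com/etc/os-release')
            size = props['content-length']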
"""
        # Check whether a readable file exists at the path. If not, signal
        # this with a dedicated 244 return code. This allows the caller to
        # distinguish the absence of a readable file from other errors,
        # e.g. from an error in awk. Only a missing file yields 244; an
        # SSH-connection problem would lead to a 255 error (and a closed
        # connection).
stat_cmd = """
ret() {{ return $1; }}
test -r {fpath}
if [ $? -eq 0 ]; then
LC_ALL=C ls -dln -- {fpath} | awk '{{print $5; exit}}'
else
ret 244
fi"""
cmd = self.format_cmd(stat_cmd, url)
ssh = self.ssh_shell_for(url)
result = ssh(cmd)
self._check_return_code(result.returncode, url, result.stderr.decode())
return {'content-length': int(result.stdout)}
def delete(self,
url: str,
*,
credential: str | None = None,
timeout: float | None = None) -> Dict:
"""Delete the target of a shh://-URL
The target can be a file or a directory. `delete` will attempt to
delete write protected targets (by setting write permissions). If
the target is a directory, the complete directory and all its
content will be deleted. `delete` will not modify the permissions
of the parent of the target. That means, it will not delete a target
in a write protected directory, but it will empty target, if target is
a directory.
See :meth:`datalad_next.url_operations.UrlOperations.delete`
for parameter documentation and exception behavior.
Raises
------
UrlOperationsResourceUnknown
For deletion targets found absent.
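
        A hypothetical call (host and path are illustrative only)::

            SshUrlOperations().delete('ssh://example.com/tmp/scratch')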
"""
delete_cmd = """
ret() {{ return $1; }}
if [ -f {fpath} ]; then
chmod u+w {fpath}
rm -f {fpath}
elif [ -d {fpath} ]; then
chmod -R u+wx {fpath}
rm -rf {fpath}
else
ret 244
fi"""
cmd = self.format_cmd(delete_cmd, url)
ssh = self.ssh_shell_for(url)
result = ssh(cmd)
self._check_return_code(result.returncode, url, result.stderr.decode())
return {}
def download(self,
from_url: str,
to_path: Path | None,
*,
# unused, but theoretically could be used to
# obtain escalated/different privileges on a system
# to gain file access
credential: str | None = None,
hash: list[str] | None = None,
timeout: float | None = None) -> Dict:
"""Download a file by streaming it through an SSH connection.
On the server-side, the file size is determined and sent. Afterwards
the file content is sent via `cat` to the SSH client.
See :meth:`datalad_next.url_operations.UrlOperations.download`
for parameter documentation and exception behavior.
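
        A hypothetical call, downloading to a local file while computing an
        MD5 digest on the fly (host and paths are illustrative only)::

            res = SshUrlOperations().download(
                'ssh://example.com/tmp/data.bin',
                Path('data.bin'),
                hash=['md5'],
            )
            # `res` contains a 'content-length' and an 'md5' key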
"""
hasher = self._get_hasher(hash)
progress_id = self._get_progress_id(from_url, str(to_path))
# get the size of the file to download
stat = self.stat(from_url, credential=credential, timeout=timeout)
expected_size = stat['content-length']
# get a shell command executor and a fixed length response generator
ssh = self.ssh_shell_for(from_url)
response_generator = FixedLengthResponseGeneratorPosix(
ssh.stdout,
expected_size
)
dst_fp = sys.stdout.buffer \
if to_path is None \
else open(to_path, 'wb')
# Localize variable access to minimize overhead
dst_fp_write = dst_fp.write
        # We already know that the file exists, so we can just `cat` it.
cmd = self.format_cmd('cat {fpath}', from_url)
result_generator = ssh.start(
cmd,
response_generator=response_generator
)
# We do not use the `shell.operations.posix.download`-method here
# because we need access to every individual chunk in order to calculate
# the hash on the fly.
for chunk in self._with_progress(
result_generator,
progress_id=progress_id,
label='downloading',
expected_size=expected_size,
start_log_msg=('Download %s to %s', from_url, to_path),
end_log_msg=('Finished download',),
update_log_msg=('Downloaded chunk',)
):
# write data
dst_fp_write(chunk)
# compute hash simultaneously
hasher.update(chunk)
        if to_path is not None:
            # only close file objects we opened ourselves, never `sys.stdout`
            dst_fp.close()
        self._check_return_code(
            result_generator.returncode,
            from_url,
            # stderr is collected as a deque of bytes chunks
            b''.join(result_generator.stderr_deque).decode()
        )
return {
**stat,
**hasher.get_hexdigest(),
}
def upload(self,
from_path: Path | None,
to_url: str,
*,
credential: str | None = None,
hash: list[str] | None = None,
timeout: float | None = None) -> Dict:
"""Upload a file by streaming it through an SSH connection.
It, more or less, runs `ssh <host> 'cat > <path>'` or
`ssh <host> 'head -c <file-size> > <path>'` on the remote side.
See :meth:`datalad_next.url_operations.UrlOperations.upload`
for parameter documentation and exception behavior.
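
        A hypothetical call, uploading a local file (host and paths are
        illustrative only)::

            res = SshUrlOperations().upload(
                Path('data.bin'),
                'ssh://example.com/tmp/data.bin',
            )
            # with ``from_path=None`` the content would be read from STDIN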
"""
if from_path is None:
source_name = '<STDIN>'
return self._perform_upload(
src_fp=sys.stdin.buffer,
source_name=source_name,
to_url=to_url,
hash_names=hash,
expected_size=None,
timeout=timeout,
)
else:
# die right away, if we lack read permissions or there is no file
with from_path.open("rb") as src_fp:
return self._perform_upload(
src_fp=src_fp,
source_name=str(from_path),
to_url=to_url,
hash_names=hash,
expected_size=from_path.stat().st_size,
timeout=timeout,
)
def _perform_upload(self,
src_fp: IO,
source_name: str,
to_url: str,
hash_names: list[str] | None,
expected_size: int | None,
timeout: float | None) -> dict:
hasher = self._get_hasher(hash_names)
        # We use a queue to implement timeouts.
        # We limit the queue to a few items in order to make `queue.put()`
        # block relatively quickly, and thereby have the progress report
        # actually track the upload, i.e. the feeding of the stdin pipe of
        # the ssh-process, and not just the feeding of the queue.
        # If we did not support timeouts, we could just use the following
        # as `input`-iterable for `iter_subproc`:
        #
        #   `iter(partial(src_fp.read, COPY_BUFSIZE), b'')`
        #
upload_queue: Queue = Queue(maxsize=2)
if expected_size:
read_cmd = f"head -c {expected_size}"
else:
read_cmd = "cat"
cmd = self.format_cmd(
            # Copy the file to its destination under a randomized name, and
            # move it to its final location after the upload completed. This
            # way the upload appears atomic, i.e. no half-uploaded file will
            # be visible at the destination URL.
            # Exit with a special code when writing or moving fails, so such
            # failures can be distinguished from general SSH access problems.
"ret() {{ return $1; }}; ( mkdir -p '{fdir}' "
f"&& {read_cmd} "
"> '{fpath}.transfer-{nonce}' "
"&& mv '{fpath}.transfer-{nonce}' '{fpath}' ) || ret 243",
to_url,
)
progress_id = self._get_progress_id(source_name, to_url)
ssh = self.ssh_shell_for(to_url)
result_generator = ssh.start(
cmd,
stdin=self._with_progress(
iter(upload_queue.get, None),
progress_id=progress_id,
label='uploading',
expected_size=expected_size,
start_log_msg=('Upload %s to %s', source_name, to_url),
end_log_msg=('Finished upload',),
update_log_msg=('Uploaded chunk',)
)
)
try:
upload_size = 0
for chunk in iter(partial(src_fp.read, COPY_BUFSIZE), b''):
                # We are just putting stuff in the queue, and rely on its
                # maxsize to block the next call, so that the progress
                # reports remain meaningful. We also rely on put-timeouts
                # to implement the overall timeout.
upload_queue.put(chunk, timeout=timeout)
# compute hash simultaneously
hasher.update(chunk)
upload_size += len(chunk)
upload_queue.put(None, timeout=timeout)
except Full:
# we had a timeout while uploading
raise TimeoutError(f'timeout while executing: {cmd}')
        if expected_size:
            consume(result_generator)
        else:
            # The remote `cat` reads from stdin until EOF, so stdin of the
            # ssh-process has to be closed for the upload-command to
            # terminate
            ssh.close()
            consume(result_generator)
            # stdin of the shell was closed, it cannot be used anymore
            self.close_shell_for(to_url)
        self._check_return_code(
            result_generator.returncode,
            to_url,
            # stderr is collected as a deque of bytes chunks
            b''.join(result_generator.stderr_deque).decode()
        )
return {
**hasher.get_hexdigest(),
# return how much was copied. we could compare with
# `expected_size` and error on mismatch, but not all
# sources can provide that (e.g. stdin)
'content-length': upload_size
}
class _SshCommandBuilder:
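    """Build remote shell commands from templates with URL-derived values

    ``substitute()`` expands ``{fpath}`` (the remote path), ``{fdir}`` (its
    parent directory), and ``{nonce}`` (a random token, e.g. for temporary
    file names during upload) in a payload command template. A hypothetical
    use (URL illustrative only)::

        builder = _SshCommandBuilder('ssh://example.com/tmp/f', cfg)
        builder.substitute('cat {fpath}')  # -> 'cat /tmp/f'
    """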
def __init__(
self,
url: str,
cfg: ConfigManager,
):
self.ssh_args, self._parsed = ssh_url2openargs(url, cfg)
self.ssh_args.extend(('-e', 'none'))
# make sure the essential pieces exist
assert self._parsed.path
        time_stamp = time.time()
        self.substitutions = dict(
            fdir=str(PurePosixPath(self._parsed.path).parent),
            fpath=self._parsed.path,
            nonce=(
                str(random.randint(1000000000, 9999999999))
                + '_'
                # fractional part of the timestamp, without the leading '0.'
                + str(time_stamp - floor(time_stamp))[2:]
            )
        )
)
def substitute(self, payload_cmd: str) -> str:
return payload_cmd.format(**self.substitutions)
def ssh_url2openargs(
url: str,
cfg: ConfigManager,
) -> tuple[list[str], ParseResult]:
"""Helper to report ssh-open arguments from a URL and config
Returns a tuple with the argument list and the parsed URL.
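
    For a URL like ``ssh://user@example.com:2222/tmp/f`` (illustrative only),
    and with no ``datalad.ssh.identityfile`` configured, this returns
    ``['-p', '2222', '-l', 'user', 'example.com']`` as the argument list,
    plus the corresponding :class:`~urllib.parse.ParseResult`.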
"""
args: list[str] = list()
parsed = urlparse(url)
# make sure the essential pieces exist
assert parsed.hostname
for opt, arg in (('-p', parsed.port),
('-l', parsed.username),
('-i', cfg.get('datalad.ssh.identityfile'))):
if arg:
# f-string, because port is not str
args.extend((opt, f'{arg}'))
# we could also use .netloc here and skip -p/-l above
args.append(parsed.hostname)
return args, parsed