"""Handler for operations, such as "download", on ssh:// URLs"""
# allow for |-type UnionType declarations
from __future__ import annotations
import logging
import sys
from functools import partial
from itertools import chain
from pathlib import (
Path,
PurePosixPath,
)
from queue import (
Full,
Queue,
)
from typing import (
Dict,
Generator,
IO,
cast,
)
from urllib.parse import (
urlparse,
ParseResult,
)
from datalad_next.consts import COPY_BUFSIZE
from datalad_next.config import ConfigManager
from datalad_next.itertools import align_pattern
from datalad_next.runners import (
iter_subproc,
CommandError,
)
from .base import UrlOperations
from .exceptions import (
UrlOperationsRemoteError,
UrlOperationsResourceUnknown,
)
lgr = logging.getLogger('datalad.ext.next.ssh_url_operations')
__all__ = ['SshUrlOperations']
class SshUrlOperations(UrlOperations):
    """Handler for operations on ``ssh://`` URLs

    For downloading files, only servers that support execution of the commands
    'printf', 'ls -nl', 'awk', and 'cat' are supported. This includes a wide
    range of operating systems, including devices that provide these commands
    via the 'busybox' software.

    .. note::
       The present implementation does not support SSH connection multiplexing,
       (re-)authentication is performed for each request. This limitation is
       likely to be removed in the future, and connection multiplexing
       supported where possible (non-Windows platforms).
    """
    # first try ls'ing the path, and catch a missing path with a dedicated 244
    # exit code, to be able to distinguish the original exit=2 that ls-call
    # from a later exit=2 from awk in case of a "fatal error".
    # when executed through ssh, only a missing file would yield 244, while
    # a connection error or other problem unrelated to the presence of a file
    # would yield a different error code (255 in case of a connection error)
    _stat_cmd = "printf \"\\1\\2\\3\"; ls '{fpath}' &> /dev/null " \
                "&& ls -nl '{fpath}' | awk 'BEGIN {{ORS=\"\\1\"}} {{print $5}}' " \
                "|| exit 244"
    _cat_cmd = "cat '{fpath}'"

    @staticmethod
    def _check_return_code(return_code: int, url: str):
        """Translate an ssh exit code into the matching exception

        Raises :class:`UrlOperationsResourceUnknown` for the dedicated
        244 exit code (remote file not found), and
        :class:`UrlOperationsRemoteError` for any other non-zero exit.
        Returns normally for exit code 0.
        """
        # At this point the subprocess has either exited, was terminated, or
        # was killed.
        if return_code == 244:
            # this is the special code for a file-not-found
            raise UrlOperationsResourceUnknown(url)
        elif return_code != 0:
            raise UrlOperationsRemoteError(
                url,
                message=f'ssh process returned {return_code}'
            )

    def stat(self,
             url: str,
             *,
             credential: str | None = None,
             timeout: float | None = None) -> Dict:
        """Gather information on a URL target, without downloading it

        See :meth:`datalad_next.url_operations.UrlOperations.stat`
        for parameter documentation and exception behavior.
        """
        ssh_cat = _SshCommandBuilder(url, self.cfg)
        cmd = ssh_cat.get_cmd(SshUrlOperations._stat_cmd)
        # `props` stays None until a complete header was received from the
        # remote. This guards against a NameError (previous behavior) when
        # the stream ends prematurely but the remote still exits with 0.
        props = None
        try:
            with iter_subproc(cmd) as stream:
                try:
                    props = self._get_props(url, stream)
                except StopIteration:
                    # we did not receive all data that should be sent, if a
                    # remote file exists. This indicates a non-existing
                    # resource or some other problem. The remotely executed
                    # command should signal the error via a non-zero exit code.
                    # That will trigger a `CommandError` below.
                    pass
        except CommandError:
            self._check_return_code(stream.returncode, url)
        if props is None:
            # the remote exited with 0 but never delivered a complete
            # header: report a remote error rather than crashing
            raise UrlOperationsRemoteError(
                url, message='incomplete response from remote')
        # strip private keys (e.g. '_stream') from the reported properties
        return {k: v for k, v in props.items() if not k.startswith('_')}

    def _get_props(self, url, stream: Generator) -> dict:
        """Parse the header that the remote stat-command sends

        The remote output must start with a three-byte magic marker
        (0x01 0x02 0x03), followed by the file size as ASCII digits,
        terminated by a single 0x01 byte; any file content follows after
        that.

        Returns a dict with ``content-length`` (int) and ``_stream``, an
        iterable over the remaining (content) bytes.

        Raises ``RuntimeError`` if the magic marker is not found, and lets
        ``StopIteration`` propagate if the stream ends prematurely.
        """
        # Any stream must start with this magic marker, or we do not
        # recognize what is happening
        # after this marker, the server will send the size of the
        # to-be-downloaded file in bytes, followed by another magic
        # b'\1', and the file content after that.
        magic_marker = b'\1\2\3'
        # use the `align_pattern` iterable to guarantee, that the magic
        # marker is always contained in a complete chunk.
        aligned_stream = align_pattern(stream, magic_marker)
        # Because the stream should start with the pattern, the first chunk of
        # the aligned stream must contain it.
        # We know that the stream will deliver bytes, cast the result
        # accordingly.
        chunk = cast(bytes, next(aligned_stream))
        if chunk[:len(magic_marker)] != magic_marker:
            raise RuntimeError("Protocol error: report header not received")
        chunk = chunk[len(magic_marker):]
        # We are done with the aligned stream, use the original stream again.
        # This is possible because `align_pattern` does not cache any data
        # after a `yield`.
        del aligned_stream
        # The length is transferred now and terminated by b'\x01'.
        while b'\x01' not in chunk:
            chunk += next(stream)
        marker_index = chunk.index(b'\x01')
        expected_size = int(chunk[:marker_index])
        chunk = chunk[marker_index + 1:]
        props = {
            'content-length': expected_size,
            # go back to the original iterator, no need to keep looking for
            # a pattern
            '_stream': chain([chunk], stream) if chunk else stream
        }
        return props

    def download(self,
                 from_url: str,
                 to_path: Path | None,
                 *,
                 # unused, but theoretically could be used to
                 # obtain escalated/different privileges on a system
                 # to gain file access
                 credential: str | None = None,
                 hash: list[str] | None = None,
                 timeout: float | None = None) -> Dict:
        """Download a file by streaming it through an SSH connection.

        On the server-side, the file size is determined and sent. Afterwards
        the file content is sent via `cat` to the SSH client. If ``to_path``
        is ``None``, the content is written to ``stdout``.

        See :meth:`datalad_next.url_operations.UrlOperations.download`
        for parameter documentation and exception behavior.
        """
        # this is pretty much shutil.copyfileobj() with the necessary
        # wrapping to perform hashing and progress reporting
        hasher = self._get_hasher(hash)
        progress_id = self._get_progress_id(from_url, str(to_path))

        dst_fp = None
        # `props` stays None until a complete header was received from the
        # remote. This guards against a NameError (previous behavior) when
        # the stream ends prematurely but the remote still exits with 0.
        props = None

        ssh_cat = _SshCommandBuilder(from_url, self.cfg)
        cmd = ssh_cat.get_cmd(
            f'{SshUrlOperations._stat_cmd}; {SshUrlOperations._cat_cmd}')
        try:
            with iter_subproc(cmd) as stream:
                try:
                    props = self._get_props(from_url, stream)
                    expected_size = props['content-length']
                    # The stream might have changed due to not yet processed,
                    # but fetched data, that is now chained in front of it.
                    # Therefore we get the updated stream from the props
                    download_stream = props.pop('_stream')

                    dst_fp = sys.stdout.buffer \
                        if to_path is None \
                        else open(to_path, 'wb')
                    # Localize variable access to minimize overhead
                    dst_fp_write = dst_fp.write

                    # download can start
                    for chunk in self._with_progress(
                            download_stream,
                            progress_id=progress_id,
                            label='downloading',
                            expected_size=expected_size,
                            start_log_msg=('Download %s to %s',
                                           from_url, to_path),
                            end_log_msg=('Finished download',),
                            update_log_msg=('Downloaded chunk',)
                    ):
                        # write data
                        dst_fp_write(chunk)
                        # compute hash simultaneously
                        hasher.update(chunk)
                except StopIteration:
                    # we did not receive all data that should be sent, if a
                    # remote file exists. This indicates a non-existing
                    # resource or some other problem. The remotely executed
                    # command should signal the error via a non-zero exit code.
                    # That will trigger a `CommandError` below.
                    pass
        except CommandError:
            self._check_return_code(stream.returncode, from_url)
        finally:
            # only close a file we opened ourselves (never stdout)
            if dst_fp and to_path is not None:
                dst_fp.close()
        if props is None:
            # the remote exited with 0 but never delivered a complete
            # header: report a remote error rather than crashing
            raise UrlOperationsRemoteError(
                from_url, message='incomplete response from remote')
        return {
            **props,
            **hasher.get_hexdigest(),
        }

    def upload(self,
               from_path: Path | None,
               to_url: str,
               *,
               credential: str | None = None,
               hash: list[str] | None = None,
               timeout: float | None = None) -> Dict:
        """Upload a file by streaming it through an SSH connection.

        It, more or less, runs `ssh <host> 'cat > <path>'`. If ``from_path``
        is ``None``, the content to upload is read from ``stdin``.

        See :meth:`datalad_next.url_operations.UrlOperations.upload`
        for parameter documentation and exception behavior.
        """
        if from_path is None:
            source_name = '<STDIN>'
            return self._perform_upload(
                src_fp=sys.stdin.buffer,
                source_name=source_name,
                to_url=to_url,
                hash_names=hash,
                # size of a stdin stream is unknown
                expected_size=None,
                timeout=timeout,
            )
        else:
            # die right away, if we lack read permissions or there is no file
            with from_path.open("rb") as src_fp:
                return self._perform_upload(
                    src_fp=src_fp,
                    source_name=str(from_path),
                    to_url=to_url,
                    hash_names=hash,
                    expected_size=from_path.stat().st_size,
                    timeout=timeout,
                )

    def _perform_upload(self,
                        src_fp: IO,
                        source_name: str,
                        to_url: str,
                        hash_names: list[str] | None,
                        expected_size: int | None,
                        timeout: float | None) -> dict:
        """Stream ``src_fp`` through ssh into the file identified by ``to_url``

        Returns a dict with the requested hash digests and the number of
        bytes copied (``content-length``).

        Raises ``TimeoutError`` when feeding the ssh process stalls for
        longer than ``timeout``, and the exceptions produced by
        :meth:`_check_return_code` on ssh failure.
        """
        hasher = self._get_hasher(hash_names)

        # we use a queue to implement timeouts.
        # we limit the queue to few items in order to `make queue.put()`
        # block relatively quickly, and thereby have the progress report
        # actually track the upload, i.e. the feeding of the stdin pipe
        # of the ssh-process, and not just the feeding of the
        # queue.
        # If we did not support timeouts, we could just use the following
        # as `input`-iterable for `iter_subproc`:
        #
        #   `iter(partial(src_fp.read, COPY_BUFSIZE), b'')
        #
        upload_queue: Queue = Queue(maxsize=2)

        cmd = _SshCommandBuilder(to_url, self.cfg).get_cmd(
            # leave special exit code when writing fails, but not the
            # general SSH access
            "( mkdir -p '{fdir}' && cat > '{fpath}' ) || exit 244"
        )

        progress_id = self._get_progress_id(source_name, to_url)
        # initialized here so the `except Full` handler below cannot hit a
        # NameError when the source yields no data at all
        chunk = b''
        upload_size = 0
        try:
            with iter_subproc(
                cmd,
                input=self._with_progress(
                    iter(upload_queue.get, None),
                    progress_id=progress_id,
                    label='uploading',
                    expected_size=expected_size,
                    start_log_msg=('Upload %s to %s', source_name, to_url),
                    end_log_msg=('Finished upload',),
                    update_log_msg=('Uploaded chunk',)
                )
            ):
                for chunk in iter(partial(src_fp.read, COPY_BUFSIZE), b''):
                    # we are just putting stuff in the queue, and rely on
                    # its maxsize to cause it to block the next call to
                    # have the progress reports be anyhow valid, we also
                    # rely on put-timeouts to implement timeout.
                    upload_queue.put(chunk, timeout=timeout)
                    # compute hash simultaneously
                    hasher.update(chunk)
                    upload_size += len(chunk)
                # the `None` sentinel terminates the input-iterable above
                upload_queue.put(None, timeout=timeout)
        except CommandError as e:
            self._check_return_code(e.returncode, to_url)
        except Full:
            if chunk != b'':
                # we had a timeout while uploading
                raise TimeoutError
        return {
            **hasher.get_hexdigest(),
            # return how much was copied. we could compare with
            # `expected_size` and error on mismatch, but not all
            # sources can provide that (e.g. stdin)
            'content-length': upload_size
        }
class _SshCommandBuilder:
def __init__(
self,
url: str,
cfg: ConfigManager,
):
self.ssh_args, self._parsed = ssh_url2openargs(url, cfg)
self.ssh_args.extend(('-e', 'none'))
# make sure the essential pieces exist
assert self._parsed.path
self.substitutions = dict(
fdir=str(PurePosixPath(self._parsed.path).parent),
fpath=self._parsed.path,
)
def get_cmd(self,
payload_cmd: str,
) -> list[str]:
cmd = ['ssh']
cmd.extend(self.ssh_args)
cmd.append(payload_cmd.format(**self.substitutions))
return cmd
def ssh_url2openargs(
        url: str,
        cfg: ConfigManager,
) -> tuple[list[str], ParseResult]:
    """Helper to report ssh-open arguments from a URL and config

    Returns a tuple with the argument list and the parsed URL.
    """
    parsed = urlparse(url)
    # make sure the essential pieces exist
    assert parsed.hostname
    option_candidates = (
        ('-p', parsed.port),
        ('-l', parsed.username),
        ('-i', cfg.get('datalad.ssh.identityfile')),
    )
    args: list[str] = []
    for flag, value in option_candidates:
        if value:
            # f-string, because port is not str
            args += [flag, f'{value}']
    # we could also use .netloc here and skip -p/-l above
    args.append(parsed.hostname)
    return args, parsed