# Source code for datalad.runner.nonasyncrunner

# emacs: -*- mode: python; py-indent-offset: 4; tab-width: 4; indent-tabs-mode: nil -*-
# ex: set sts=4 ts=4 sw=4 et:
# ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the datalad package for the
#   copyright and license terms.
#
# ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""
Thread based subprocess execution with stdout and stderr passed to protocol objects
"""

from __future__ import annotations

import enum
import logging
import subprocess
import threading
import time
from collections import deque
from collections.abc import Generator
from queue import (
    Empty,
    Queue,
)
from subprocess import Popen
from typing import (
    IO,
    Any,
    Optional,
)

from datalad.utils import on_windows

from .exception import CommandError
from .protocol import (
    GeneratorMixIn,
    WitlessProtocol,
)
from .runnerthreads import (
    IOState,
    ReadThread,
    WaitThread,
    WriteThread,
    _try_close,
)


__docformat__ = 'restructuredtext'

lgr = logging.getLogger("datalad.runner.nonasyncrunner")

STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2


# A helper to type-safe retrieval of a Popen-fileno, if data exchange was
# requested.
def _get_fileno(active: bool,
                popen_std_x: Optional[IO]
                ) -> Optional[int]:
    if active:
        assert popen_std_x is not None
        return popen_std_x.fileno()
    return None


class _ResultGenerator(Generator):
    """
    Generator returned by run_command if the protocol class
    is a subclass of `datalad.runner.protocol.GeneratorMixIn`
    """
    class GeneratorState(enum.Enum):
        # Progression of the generator over the lifetime of the
        # subprocess. `_locked_send` falls through from one state to the
        # next as the queue drains and the process winds down.
        initialized = 0
        process_running = 1
        process_exited = 2
        connection_lost = 3
        waiting_for_process = 4
        exhausted = 5

    def __init__(self,
                 runner: ThreadedRunner,
                 result_queue: deque
                 ) -> None:

        super().__init__()
        # The runner that drives the subprocess and fills `result_queue`.
        self.runner = runner
        # Results put here by the protocol via `GeneratorMixIn.send_result()`.
        self.result_queue = result_queue
        # Exit code of the subprocess; set once the process has exited.
        self.return_code = None
        # NOTE(review): the state is set to `process_running` immediately,
        # which appears to make the `initialized`-branch in `_locked_send`
        # unreachable -- confirm whether `initialized` is still needed.
        self.state = self.GeneratorState.process_running
        self.all_closed = False
        # Serializes `send()`-calls issued from multiple threads.
        self.send_lock = threading.Lock()

    def _check_result(self):
        # Delegate to the runner, which raises CommandError on a non-zero
        # return code (if `exception_on_error` was requested).
        self.runner._check_result()

    def send(self, message):
        # Generator protocol entry point; serialized because the runner's
        # queue processing is not re-entrant.
        with self.send_lock:
            return self._locked_send(message)

    def _locked_send(self, message):
        # State machine that yields queued results and advances through
        # the subprocess shutdown sequence once the queue drains.
        if self.state == self.GeneratorState.initialized:
            if message is not None:
                raise RuntimeError(
                    f"sent non-None message {message!r} to initialized generator "
                )
            self.state = self.GeneratorState.process_running

        runner = self.runner

        if self.state == self.GeneratorState.process_running:
            # If we have elements in the result queue, return one
            while len(self.result_queue) == 0 and runner.should_continue():
                runner.process_queue()
            if len(self.result_queue) > 0:
                return self.result_queue.popleft()

            # The process must have exited
            # Let the protocol prepare the result. This has to be done after
            # the loop was left to ensure that all data from stdout and stderr
            # is processed.
            runner.protocol.process_exited()
            self.return_code = runner.process.poll()
            # May raise CommandError for non-zero return codes.
            self._check_result()
            self.state = self.GeneratorState.process_exited

        if self.state == self.GeneratorState.process_exited:
            # The protocol might have added result in the
            # _prepare_result()- or in the process_exited()-
            # callback. Those are returned here.
            if len(self.result_queue) > 0:
                return self.result_queue.popleft()
            runner.ensure_stdin_stdout_stderr_closed()
            runner.protocol.connection_lost(None)   # TODO: check for exceptions
            runner.wait_for_threads()
            runner._set_process_exited()
            self.state = self.GeneratorState.connection_lost

        if self.state == self.GeneratorState.connection_lost:
            # Get all results that were enqueued in
            # state: GeneratorState.process_exited.
            if len(self.result_queue) > 0:
                return self.result_queue.popleft()
            self.state = self.GeneratorState.exhausted
            # Release ownership so that another thread blocked in
            # `ThreadedRunner._locked_run()` may proceed.
            runner.owning_thread = None
            with runner.generator_condition:
                runner.generator = None
                runner.generator_condition.notify()

        if self.state == self.GeneratorState.exhausted:
            # Conventional generator-exhaustion signal, carrying the
            # process return code as StopIteration value.
            raise StopIteration(self.return_code)

        raise RuntimeError(f"unknown state: {self.state}")

    def throw(self, exception_type, value=None, trace_back=None):
        # Default Generator.throw-behavior: raise inside the generator.
        return Generator.throw(self, exception_type, value, trace_back)


class ThreadedRunner:
    """
    A class that contains a naive implementation for concurrent sub-process
    execution. It uses `subprocess.Popen` and threads to read from stdout
    and stderr of the subprocess, and to write to stdin of the subprocess.
    All read data and timeouts are passed to a protocol instance, which can
    create the final result.
    """
    # Interval in seconds after which we check that a subprocess
    # is still running.
    timeout_resolution = 0.2

    def __init__(self,
                 cmd: str | list,
                 protocol_class: type[WitlessProtocol],
                 stdin: int | IO | bytes | Queue[Optional[bytes]] | None,
                 protocol_kwargs: Optional[dict] = None,
                 timeout: Optional[float] = None,
                 exception_on_error: bool = True,
                 **popen_kwargs
                 ):
        """
        Parameters
        ----------
        cmd : list or str
            Command to be executed, passed to `subprocess.Popen`. If cmd
            is a str, `subprocess.Popen` will be called with `shell=True`.
        protocol_class : WitlessProtocol
            Class or subclass which will be instantiated for managing
            communication with the subprocess. If the protocol is a
            subclass of `datalad.runner.protocol.GeneratorMixIn`, `run()`
            will return a `Generator` which yields whatever the protocol
            callback fed into `GeneratorMixIn.send_result()`. If the
            protocol is not a subclass of
            `datalad.runner.protocol.GeneratorMixIn`, `run()` will return
            the result created by the protocol method `_generate_result`.
        stdin : file-like, bytes, Queue, or None
            If stdin is a file-like, it will be directly used as stdin for
            the subprocess. The caller is responsible for writing to it
            and closing it. If stdin is a bytes, it will be fed to stdin
            of the subprocess. If all data is written, stdin will be
            closed. If stdin is a Queue, all elements (bytes) put into the
            Queue will be passed to stdin until None is read from the
            queue. If None is read, stdin of the subprocess is closed. If
            stdin is None, nothing will be sent to stdin of the
            subprocess. More precisely, `subprocess.Popen` will be called
            with `stdin=None`.
        protocol_kwargs : dict, optional
            Passed to the protocol class constructor.
        timeout : float, optional
            If a non-`None` timeout is specified, the `timeout`-method of
            the protocol will be called if:

            - stdin-write, stdout-read, or stderr-read time out. In this
              case the file descriptor will be given as argument to the
              timeout-method. If the timeout-method returns `True`, the
              file descriptor will be closed.

            - process.wait() timeout: if waiting for process completion
              after stdin, stderr, and stdout takes longer than `timeout`
              seconds, the timeout-method will be called with the argument
              `None`. If it returns `True`, the process will be
              terminated.
        exception_on_error : bool, optional
            This argument is only interpreted if the protocol is a
            subclass of `GeneratorMixIn`. If it is `True` (default), a
            `CommandError` is raised by the generator if the sub process
            exited with a return code not equal to zero. If the parameter
            is `False`, no exception is raised. In both cases the return
            code can be read from the attribute `return_code` of the
            generator.
        popen_kwargs : dict, optional
            Passed to `subprocess.Popen`, will typically be parameters
            supported by `subprocess.Popen`. Note that `bufsize`, `stdin`,
            `stdout`, `stderr`, and `shell` will be overwritten
            internally.
        """
        self.cmd = cmd
        self.protocol_class = protocol_class
        self.stdin = stdin
        self.protocol_kwargs = protocol_kwargs or {}
        self.timeout = timeout
        self.exception_on_error = exception_on_error
        self.popen_kwargs = popen_kwargs
        # Whether stdout/stderr should be piped is determined by the
        # protocol class attributes.
        self.catch_stdout = self.protocol_class.proc_out
        self.catch_stderr = self.protocol_class.proc_err
        self.write_stdin: bool = False
        self.stdin_queue: Optional[Queue] = None
        self.process_stdin_fileno: Optional[int] = None
        self.process_stdout_fileno: Optional[int] = None
        self.process_stderr_fileno: Optional[int] = None
        self.stderr_enqueueing_thread: Optional[ReadThread] = None
        self.stdout_enqueueing_thread: Optional[ReadThread] = None
        self.stdin_enqueueing_thread: Optional[WriteThread] = None
        self.process_waiting_thread: Optional[WaitThread] = None
        self.process_running: bool = False
        # Central event queue fed by the reader/writer/waiter threads.
        self.output_queue: Queue = Queue()
        self.process_removed: bool = False
        self.generator: Optional[_ResultGenerator] = None
        self.process: Optional[Popen[Any]] = None
        self.return_code: Optional[int] = None
        # Per-source last-activity time stamps used for timeout detection;
        # the key `None` stands for the process itself.
        self.last_touched: dict[Optional[int], float] = dict()
        # Sources (filenos, plus `None` for the process) that are still live.
        self.active_file_numbers: set[Optional[int]] = set()
        self.stall_check_interval = 10
        self.initialization_lock = threading.Lock()
        self.generator_condition = threading.Condition()
        self.owning_thread: Optional[int] = None

        # Pure declarations
        self.protocol: WitlessProtocol
        self.fileno_mapping: dict[Optional[int], int]
        self.fileno_to_file: dict[Optional[int], Optional[IO]]
        self.file_to_fileno: dict[IO, int]
        self.result: dict

    def _check_result(self):
        # Raise CommandError if error-raising was requested and the
        # subprocess exited with a non-zero code. `None` means the process
        # has not (yet) been seen as exited and is not treated as an error.
        if self.exception_on_error is True:
            if self.return_code not in (0, None):
                protocol = self.protocol
                # Decode whatever stdout/stderr data the protocol captured.
                decoded_output = {
                    source: protocol.fd_infos[fileno][1].decode(protocol.encoding)
                    for source, fileno in (
                        ("stdout", protocol.stdout_fileno),
                        ("stderr", protocol.stderr_fileno))
                    if protocol.fd_infos[fileno][1] is not None
                }
                raise CommandError(
                    cmd=self.cmd,
                    code=self.return_code,
                    stdout=decoded_output.get("stdout", None),
                    stderr=decoded_output.get("stderr", None)
                )
    def run(self) -> dict | _ResultGenerator:
        """
        Run the command as specified in __init__.

        This method is not re-entrant. Furthermore, if the protocol is a
        subclass of ``GeneratorMixIn``, and the generator has not been
        exhausted, i.e. it has not raised `StopIteration`, this method
        should not be called again. If it is called again before the
        generator is exhausted, a ``RuntimeError`` is raised. In the
        non-generator case, a second caller will be suspended until the
        first caller has returned.

        Returns
        -------
        Any
            If the protocol is not a subclass of ``GeneratorMixIn``, the
            result of protocol._prepare_result will be returned.

        Generator
            If the protocol is a subclass of ``GeneratorMixIn``, a
            Generator will be returned. This allows to use this method in
            constructs like::

                for protocol_output in runner.run():
                    ...

            Where the iterator yields whatever
            protocol.pipe_data_received sends into the generator. If all
            output was yielded and the process has terminated, the
            generator will raise StopIteration(return_code), where
            return_code is the return code of the process. The return
            code of the process will also be stored in the
            "return_code"-attribute of the runner. So you could write::

                gen = runner.run()
                for file_descriptor, data in gen:
                    ...

                # get the return code of the process
                result = gen.return_code
        """
        # All real work happens in `_locked_run`; the lock only serializes
        # concurrent callers.
        with self.initialization_lock:
            return self._locked_run()
    def _locked_run(self) -> dict | _ResultGenerator:
        """Start the subprocess and its I/O threads.

        Must only be called with `self.initialization_lock` held.
        """
        with self.generator_condition:
            if self.generator is not None:
                # A previous generator-run has not been exhausted yet.
                if self.owning_thread == threading.get_ident():
                    # Re-entry from the owning thread would dead-lock on
                    # the condition-wait below, so fail loudly instead.
                    raise RuntimeError(
                        "ThreadedRunner.run() was re-entered by already owning "
                        f"thread {threading.get_ident()}. The execution is "
                        f"still owned by thread {self.owning_thread}"
                    )
                # Wait until the owning thread's generator is exhausted;
                # the generator sets `self.generator = None` and notifies.
                self.generator_condition.wait()
                assert self.generator is None

        # Determine how stdin of the subprocess will be fed.
        # NOTE(review): `isinstance(..., IO)` tests against `typing.IO`,
        # which concrete file objects do not subclass; file-likes therefore
        # fall into the `else`-branch -- with the same effect
        # (`write_stdin = False`), but confirm this is intended.
        if isinstance(self.stdin, (int, IO, type(None))):
            # We will not write anything to stdin. If the caller passed a
            # file-like he can write to it from a different thread.
            self.write_stdin = False
        elif isinstance(self.stdin, bytes):
            # Establish a queue to write to the process and
            # enqueue the input that is already provided.
            self.write_stdin = True
            self.stdin_queue = Queue()
            self.stdin_queue.put(self.stdin)
            self.stdin_queue.put(None)
        elif isinstance(self.stdin, Queue):
            # Establish a queue to write to the process.
            self.write_stdin = True
            self.stdin_queue = self.stdin
        else:
            # We do not recognize the input class will and just pass is through
            # to Popen(). We assume that the caller handles any writing if
            # desired.
            self.write_stdin = False

        self.protocol = self.protocol_class(**self.protocol_kwargs)

        # Popen arguments: only request pipes for the streams that the
        # protocol (or the stdin handling above) actually uses.
        kwargs = {
            **self.popen_kwargs,
            **dict(
                bufsize=0,
                stdin=subprocess.PIPE if self.write_stdin else self.stdin,
                stdout=subprocess.PIPE if self.catch_stdout else None,
                stderr=subprocess.PIPE if self.catch_stderr else None,
                shell=True if isinstance(self.cmd, str) else False  # nosec
            )
        }

        if self.process is not None:
            raise RuntimeError(f"Process already running {self.process.pid}")

        self.return_code = None
        try:
            # The following command is generated internally by datalad
            # and trusted. Security check is therefore skipped.
            self.process = Popen(self.cmd, **kwargs)  # nosec
        except OSError as e:
            if not on_windows and "argument list too long" in str(e).lower():
                lgr.error(
                    "Caught exception suggesting too large stack size limits. "
                    "Hint: use 'ulimit -s' command to see current limit and "
                    "e.g. 'ulimit -s 8192' to reduce it to avoid this "
                    "exception. See "
                    "https://github.com/datalad/datalad/issues/6106 for more "
                    "information."
                )
            raise

        self.process_running = True
        # `None` in the active set represents the subprocess itself.
        self.active_file_numbers.add(None)

        self.process_stdin_fileno = _get_fileno(self.write_stdin, self.process.stdin)
        self.process_stdout_fileno = _get_fileno(self.catch_stdout, self.process.stdout)
        self.process_stderr_fileno = _get_fileno(self.catch_stderr, self.process.stderr)

        # We pass process as transport-argument. It does not have the same
        # semantics as the asyncio-signature, but since it is only used in
        # WitlessProtocol, all necessary changes can be made there.
        self.protocol.connection_made(self.process)

        # Map the pipe file numbers to stdout and stderr file number, because
        # the latter are hardcoded in the protocol code
        self.fileno_mapping = {
            self.process_stdout_fileno: STDOUT_FILENO,
            self.process_stderr_fileno: STDERR_FILENO,
            self.process_stdin_fileno: STDIN_FILENO,
        }
        if None in self.fileno_mapping:
            # A `None`-key appears for every stream that is not piped.
            self.fileno_mapping.pop(None)

        self.fileno_to_file = {
            self.process_stdout_fileno: self.process.stdout,
            self.process_stderr_fileno: self.process.stderr,
            self.process_stdin_fileno: self.process.stdin
        }
        if None in self.fileno_to_file:
            self.fileno_to_file.pop(None)

        self.file_to_fileno = {
            f: f.fileno()
            for f in (
                self.process.stdout,
                self.process.stderr,
                self.process.stdin
            ) if f is not None
        }

        current_time = time.time()
        if self.timeout:
            # Time stamp for process-wait timeouts (key `None`).
            self.last_touched[None] = current_time

        cmd_string = self.cmd if isinstance(self.cmd, str) else " ".join(self.cmd)
        if self.catch_stderr:
            if self.timeout:
                self.last_touched[self.process_stderr_fileno] = current_time
            self.active_file_numbers.add(self.process_stderr_fileno)
            # NOTE(review): this unconditional assignment makes the
            # `if self.timeout:`-guarded assignment above redundant and
            # registers the fileno in `last_touched` even when no timeout
            # was configured -- confirm which variant is intended.
            self.last_touched[self.process_stderr_fileno] = current_time
            assert self.process.stderr is not None
            self.stderr_enqueueing_thread = ReadThread(
                identifier="STDERR: " + cmd_string[:20],
                signal_queues=[self.output_queue],
                user_info=self.process_stderr_fileno,
                source=self.process.stderr,
                destination_queue=self.output_queue)
            self.stderr_enqueueing_thread.start()

        if self.catch_stdout:
            if self.timeout:
                self.last_touched[self.process_stdout_fileno] = current_time
            self.active_file_numbers.add(self.process_stdout_fileno)
            # NOTE(review): same redundancy as in the stderr-branch above.
            self.last_touched[self.process_stdout_fileno] = current_time
            assert self.process.stdout is not None
            self.stdout_enqueueing_thread = ReadThread(
                identifier="STDOUT: " + cmd_string[:20],
                signal_queues=[self.output_queue],
                user_info=self.process_stdout_fileno,
                source=self.process.stdout,
                destination_queue=self.output_queue)
            self.stdout_enqueueing_thread.start()

        if self.write_stdin:
            # No timeouts for stdin
            self.active_file_numbers.add(self.process_stdin_fileno)
            assert self.stdin_queue is not None
            assert self.process.stdin is not None
            self.stdin_enqueueing_thread = WriteThread(
                identifier="STDIN: " + cmd_string[:20],
                user_info=self.process_stdin_fileno,
                signal_queues=[self.output_queue],
                source_queue=self.stdin_queue,
                destination=self.process.stdin)
            self.stdin_enqueueing_thread.start()

        # Watch for process termination; posts an event to `output_queue`.
        self.process_waiting_thread = WaitThread(
            "process_waiter",
            [self.output_queue],
            self.process)
        self.process_waiting_thread.start()

        if isinstance(self.protocol, GeneratorMixIn):
            # Generator protocols get a lazy result generator; the caller
            # drives `process_queue()` by iterating it.
            self.generator = _ResultGenerator(
                self,
                self.protocol.result_queue
            )
            self.owning_thread = threading.get_ident()
            return self.generator

        # Blocking mode: run the event loop to completion right here.
        return self.process_loop()
    def process_loop(self) -> dict:
        """Drive queue processing to completion and return the protocol result."""
        # Process internal messages until no more active file descriptors
        # are present. This works because active file numbers are only
        # removed when an EOF is received in `self.process_queue`.
        while self.should_continue():
            self.process_queue()

        # Let the protocol prepare the result. This has to be done after
        # the loop was left to ensure that all data from stdout and stderr
        # is processed.
        self.result = self.protocol._prepare_result()
        self.protocol.process_exited()

        # Ensure that all communication channels are closed.
        self.ensure_stdin_stdout_stderr_closed()
        self.protocol.connection_lost(None)  # TODO: check exception
        self.wait_for_threads()
        self._set_process_exited()
        return self.result
    def _handle_file_timeout(self, source):
        # A read/write on `source` timed out; if the protocol's timeout
        # callback returns True, the file descriptor is closed and removed.
        if self.protocol.timeout(self.fileno_mapping[source]) is True:
            self.remove_file_number(source)

    def _handle_process_timeout(self):
        # Waiting for process exit timed out; if the protocol agrees,
        # close all pipes and terminate the subprocess.
        if self.protocol.timeout(None) is True:
            self.ensure_stdin_stdout_stderr_closed()
            self.process.terminate()
            self.process.wait()
            self.remove_process()

    def _handle_source_timeout(self, source):
        # Dispatch on the timed-out source: `None` stands for the process
        # itself, any other value is a pipe file number.
        if source is None:
            self._handle_process_timeout()
        else:
            self._handle_file_timeout(source)

    def _update_timeouts(self) -> bool:
        """Fire timeout handlers for all stale sources.

        Returns `True` if at least one source's last activity is older
        than `self.timeout` seconds.
        """
        # Work on a snapshot because handlers may mutate `last_touched`.
        last_touched = list(self.last_touched.items())
        new_times = dict()
        current_time = time.time()
        timeout_occurred = False
        for source, last_time in last_touched:
            if self.timeout is not None and current_time - last_time >= self.timeout:
                # Re-arm the timer for this source before handling it.
                new_times[source] = current_time
                self._handle_source_timeout(source)
                timeout_occurred = True

        # Merge instead of assigning per key, because the handlers above
        # may have removed entries from `self.last_touched`.
        # NOTE(review): this merge also re-adds entries that a handler
        # just removed (they are in `new_times`) -- confirm that repeated
        # timeout callbacks for an already-removed source are intended.
        self.last_touched = {
            **self.last_touched,
            **new_times}
        return timeout_occurred
[docs] def process_timeouts(self) -> bool: """Check for timeouts This method checks whether a timeout occurred since it was called last. If a timeout occurred, the timeout handler is called. Returns: bool Return `True` if at least one timeout occurred, `False` if no timeout occurred. """ if self.timeout is not None: return self._update_timeouts() return False
[docs] def should_continue(self) -> bool: # Continue with queue processing if there is still a process or # monitored files, or if there are still elements in the output queue. return ( len(self.active_file_numbers) > 0 or not self.output_queue.empty() ) and not self.is_stalled()
[docs] def is_stalled(self) -> bool: # If all queue-filling threads have exited and the queue is empty, we # might have a stall condition. live_threads = [ thread.is_alive() for thread in ( self.stdout_enqueueing_thread, self.stderr_enqueueing_thread, self.process_waiting_thread, ) if thread is not None] return not any(live_threads) and self.output_queue.empty()
    def check_for_stall(self) -> bool:
        """Periodically probe for a stall; return `True` if one is detected.

        The (comparatively expensive) stall probe is only performed every
        11th call, counted down via `self.stall_check_interval`.
        """
        if self.stall_check_interval == 0:
            # Reset the countdown and perform the actual probe.
            self.stall_check_interval = 11
            if self.is_stalled():
                lgr.warning(
                    "ThreadedRunner.process_queue(): stall detected")
                return True
        self.stall_check_interval -= 1
        return False
def _set_process_exited(self): self.return_code = self.process.poll() self.process = None self.process_running = False
    def process_queue(self):
        """
        Get a single event from the queue or handle a timeout. This method
        might modify the set of active file numbers if a file-closed event
        is read from the output queue, or if a timeout-callback returns True.
        """
        data = None
        while True:
            # We do not need a user provided timeout here. If
            # self.timeout is None, no timeouts are reported anyway.
            # If self.timeout is not None, and any enqueuing (stdin)
            # or de-queuing (stdout, stderr) operation takes longer than
            # self.timeout, we will get a queue entry for that.
            # We still use a "system"-timeout, i.e.
            # `ThreadedRunner.timeout_resolution`, to check whether the
            # process is still running.
            try:
                file_number, state, data = self.output_queue.get(
                    timeout=ThreadedRunner.timeout_resolution)
                break
            except Empty:
                # No event arrived within the resolution interval; give
                # stall detection and timeout handling a chance to abort.
                if self.check_for_stall() is True:
                    return
                if self.process_timeouts():
                    return
                continue

        if state == IOState.process_exit:
            self.remove_process()
            return

        if self.write_stdin and file_number == self.process_stdin_fileno:
            # The only data-signal we expect from stdin thread
            # is None, indicating that the thread ended
            assert data is None
            self.remove_file_number(self.process_stdin_fileno)
        elif self.catch_stderr or self.catch_stdout:
            if data is None:
                # Received an EOF for stdout or stderr.
                self.remove_file_number(file_number)
            else:
                # Call the protocol handler for data
                assert isinstance(data, bytes)
                # Update the activity time stamp for timeout bookkeeping.
                self.last_touched[file_number] = time.time()
                self.protocol.pipe_data_received(
                    self.fileno_mapping[file_number], data)
    def remove_process(self):
        """Remove the process (represented by `None`) from the active set."""
        if None not in self.active_file_numbers:
            # Might already be removed due to a timeout callback returning
            # True and subsequent removal of the process.
            return

        self.active_file_numbers.remove(None)
        if self.timeout:
            del self.last_touched[None]

        # Remove stdin from the active set because the process will
        # no longer consume input from stdin. This is done by enqueuing
        # None to the stdin queue.
        if self.write_stdin:
            self.stdin_queue.put(None)

        # Capture the return code while `self.process` is still set.
        self.return_code = self.process.poll()
    def remove_file_number(self, file_number: int):
        """
        Remove a file number from the active set and from the timeout set.
        """
        # TODO: check exception
        # Let the protocol know that the connection was lost.
        self.protocol.pipe_connection_lost(
            self.fileno_mapping[file_number],
            None)

        if file_number in self.active_file_numbers:
            # Remove the file number from the set of active numbers.
            self.active_file_numbers.remove(file_number)

        # If we are checking timeouts, remove the file number from
        # timeouts.
        if self.timeout and file_number in self.last_touched:
            del self.last_touched[file_number]

        # Best-effort close of the underlying file object.
        _try_close(self.fileno_to_file[file_number])
[docs] def close_stdin(self): if self.stdin_queue: self.stdin_queue.put(None)
    def _ensure_closed(self, file_objects):
        # Close the given subprocess file objects and de-register them
        # from the timeout- and active-source bookkeeping.
        for file_object in file_objects:
            if file_object is not None:
                file_number = self.file_to_fileno.get(file_object, None)
                if file_number is not None:
                    if self.timeout and file_number in self.last_touched:
                        del self.last_touched[file_number]
                    if file_number in self.active_file_numbers:
                        self.active_file_numbers.remove(file_number)
                # Best-effort close; ignores already-closed objects.
                _try_close(file_object)
[docs] def ensure_stdin_stdout_stderr_closed(self): self.close_stdin() self._ensure_closed( ( self.process.stdin, self.process.stdout, self.process.stderr ) )
[docs] def ensure_stdout_stderr_closed(self): self._ensure_closed((self.process.stdout, self.process.stderr))
[docs] def wait_for_threads(self): for thread in (self.stderr_enqueueing_thread, self.stdout_enqueueing_thread, self.stdin_enqueueing_thread): if thread is not None: thread.request_exit()
def run_command(cmd: str | list,
                protocol: type[WitlessProtocol],
                stdin: int | IO | bytes | Queue[Optional[bytes]] | None,
                protocol_kwargs: Optional[dict] = None,
                timeout: Optional[float] = None,
                exception_on_error: bool = True,
                **popen_kwargs) -> dict | _ResultGenerator:
    """
    Run a command in a subprocess

    This function delegates the execution to an instance of
    `ThreadedRunner`; please see `ThreadedRunner.__init__()` for a
    documentation of the parameters, and `ThreadedRunner.run()` for a
    documentation of the return values.
    """
    return ThreadedRunner(
        cmd=cmd,
        protocol_class=protocol,
        stdin=stdin,
        protocol_kwargs=protocol_kwargs,
        timeout=timeout,
        exception_on_error=exception_on_error,
        **popen_kwargs,
    ).run()