# Source code for ai_flow.operators.bash

#
# Copyright 2022 The AI Flow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
import os
import signal
import threading
from subprocess import Popen, STDOUT, TimeoutExpired, PIPE
from tempfile import TemporaryDirectory
from typing import Optional

from ai_flow.common.exception.exceptions import AIFlowException
from ai_flow.common.util.thread_utils import StoppableThread
from ai_flow.model.context import Context
from ai_flow.model.operator import AIFlowOperator


class BashOperator(AIFlowOperator):
    """Operator that runs a shell command via ``bash -c`` in a subprocess.

    The command is started in a new session (``os.setsid``) so that ``stop``
    can signal its whole process group. It runs inside a private temporary
    working directory, and its combined stdout/stderr is forwarded line by
    line to this operator's logger by a background reader thread.

    :param name: The name of the operator.
    :param bash_command: The command string passed to ``bash -c``.
    """

    # Upper bound (seconds) to wait for the reader thread to log output that is
    # still buffered in the pipe after the command has exited.
    _LOG_DRAIN_TIMEOUT_SECONDS = 5

    def __init__(self, name: str, bash_command: str, **kwargs):
        super().__init__(name, **kwargs)
        self.bash_command = bash_command
        self.sub_process = None
        self.log_reader = None
        # Working directory of the running command; it must outlive ``start``
        # and is removed only after the command has exited.
        self._tmp_dir = None

    def start(self, context: Context):
        """Launch the bash command without waiting for it to finish.

        :param context: The execution context (not used by this operator).
        """
        def pre_exec():
            # Restore default signal disposition and invoke setsid so that the
            # command and its children form a new process group.
            for sig in ('SIGPIPE', 'SIGXFSZ'):
                if hasattr(signal, sig):
                    signal.signal(getattr(signal, sig), signal.SIG_DFL)
            os.setsid()

        # Not used as a context manager: the command keeps running in this
        # directory after start() returns, so it is cleaned up only once the
        # process has exited (in await_termination or stop).
        self._tmp_dir = TemporaryDirectory(prefix='aiflow_tmp')
        self.log.info('Running command: %s', self.bash_command)
        try:
            self.sub_process = Popen(
                ['bash', "-c", self.bash_command],
                stdout=PIPE,
                stderr=STDOUT,
                cwd=self._tmp_dir.name,
                preexec_fn=pre_exec,
            )
        except Exception:
            # The process never started, so nothing else will clean up.
            self._cleanup_tmp_dir()
            raise
        self.log_reader = StoppableThread(target=self._read_output)
        self.log_reader.start()

    def stop(self, context: Context):
        """Send SIGTERM to the command's process group and reap the process.

        Blocks until the bash process exits, then stops the output reader and
        removes the temporary working directory. Does nothing if the command
        was never started.

        :param context: The execution context (not used by this operator).
        """
        if self.sub_process is None:
            return
        self.log.info('Sending SIGTERM signal to bash process group')
        try:
            os.killpg(os.getpgid(self.sub_process.pid), signal.SIGTERM)
        except ProcessLookupError:
            # The process already exited (and may have been reaped), so there
            # is no process group left to signal.
            self.log.info('Bash process %s has already exited', self.sub_process.pid)
        finally:
            # Need to call sub_process.wait() to avoid becoming zombie processes
            self.sub_process.wait()
            self._stop_log_reader()
            self._cleanup_tmp_dir()

    def await_termination(self, context: Context, timeout: Optional[int] = None):
        """Wait for the bash command to finish.

        :param context: The execution context (not used by this operator).
        :param timeout: Maximum seconds to wait; ``None`` waits indefinitely.
        :raises TimeoutExpired: If the command is still running after ``timeout``.
            The command is left running; call ``stop`` to terminate it.
        :raises AIFlowException: If the command was never started or exited
            with a non-zero return code.
        """
        if self.sub_process is None:
            raise AIFlowException('Bash command has not been started.')
        try:
            self.sub_process.wait(timeout=timeout)
        except TimeoutExpired:
            self.log.error("Timeout to wait bash operator to be finished in {} seconds".format(timeout))
            raise
        finally:
            if self.sub_process.returncode is None:
                # Still running (timeout or interrupted wait): stop logging but
                # keep the working directory, which the command may still use.
                if self.log_reader is not None:
                    self.log_reader.stop()
        # The process has exited: let the reader log what is left in the pipe
        # before stopping it, so short-lived commands do not lose their output.
        self._stop_log_reader()
        self._cleanup_tmp_dir()
        self.log.info('Command exited with return code %s', self.sub_process.returncode)
        if self.sub_process.returncode != 0:
            raise AIFlowException('Bash command failed. The command returned a non-zero exit code.')

    def _read_output(self):
        """Forward the command's combined stdout/stderr to the log, line by line."""
        self.log.info('Output:')
        with self.sub_process.stdout as stdout:
            for raw_line in iter(stdout.readline, b''):
                # Keep draining the pipe even after stop() so the command never
                # blocks on a full pipe; only the logging is suppressed.
                if not threading.current_thread().stopped():
                    # 'replace' keeps a non-UTF-8 line from killing this thread.
                    line = raw_line.decode('utf-8', errors='replace').rstrip()
                    self.log.info("%s", line)

    def _stop_log_reader(self):
        """Give the reader a bounded chance to reach EOF, then stop it."""
        if self.log_reader is None:
            return
        # After the command exits, the reader normally ends on EOF. The timeout
        # guards against background children that keep the pipe open.
        self.log_reader.join(timeout=self._LOG_DRAIN_TIMEOUT_SECONDS)
        self.log_reader.stop()

    def _cleanup_tmp_dir(self):
        """Remove the command's temporary working directory, if one exists."""
        tmp_dir = getattr(self, '_tmp_dir', None)
        if tmp_dir is not None:
            self._tmp_dir = None
            tmp_dir.cleanup()