Source code for ai_flow.operators.spark.spark_sql

#
# Copyright 2022 The AI Flow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
import os
import subprocess
from typing import Optional

from ai_flow.common.env import expand_env_var
from ai_flow.common.exception.exceptions import AIFlowException
from ai_flow.model.context import Context
from ai_flow.model.operator import AIFlowOperator


class SparkSqlOperator(AIFlowOperator):
    """
    Operator that runs a SQL statement or script through the ``spark-sql`` CLI.

    The command is launched as a child process; its combined stdout/stderr is
    forwarded to the operator log. SparkSqlOperator only supports client mode
    for now.

    :param name: Name of the operator; also used to derive the default
        application name.
    :param sql: Either an inline SQL statement, or a path ending in ``.sql`` /
        ``.hql`` that is handed to ``spark-sql`` as a script file.
    :param master: Spark master passed to ``spark-sql`` via ``--master``.
    :param application_name: Spark application name; defaults to
        ``spark_sql_task_<name>``.
    :param executable_path: Explicit path to the ``spark-sql`` executable. When
        omitted, ``${SPARK_HOME}/bin/spark-sql`` is used.
    """

    def __init__(self,
                 name: str,
                 sql: str,
                 master: str = 'yarn',
                 application_name: Optional[str] = None,
                 executable_path: Optional[str] = None,
                 **kwargs):
        super().__init__(name, **kwargs)
        self._application_name = application_name or f'spark_sql_task_{name}'
        self._master = master
        self._executable_path = executable_path
        self._sql = sql
        # Both are populated by start().
        self._spark_sql_cmd = None
        self._process = None

    def start(self, context: Context):
        """
        Build the ``spark-sql`` command line and launch it as a subprocess.

        stderr is merged into stdout so that await_termination() can forward a
        single text stream to the log.

        :raises AIFlowException: If the ``spark-sql`` executable cannot be found.
        """
        self._spark_sql_cmd = self._build_spark_sql_command()
        self._process = subprocess.Popen(
            self._spark_sql_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            bufsize=-1,
            universal_newlines=True,
        )

    def await_termination(self, context: Context, timeout: Optional[int] = None):
        """
        Stream the process output to the log and wait for the process to exit.

        NOTE: ``timeout`` is accepted for interface compatibility but is not
        enforced here; the call blocks until the process closes its output and
        exits.

        :raises AIFlowException: If the process exits with a non-zero code.
        """
        stdout = self._process.stdout
        # Skip the read when the pipe was already drained and closed by an
        # earlier call; wait() below still reports the cached exit code.
        if stdout is not None and not stdout.closed:
            # The context manager closes the pipe once EOF is reached so the
            # file descriptor is not leaked.
            with stdout:
                for line in stdout:
                    # Lines arrive with their newline; the logger adds its own.
                    self.log.info(line.rstrip("\n"))

        ret_code = self._process.wait()
        if ret_code:
            raise AIFlowException(
                "Cannot execute '{}' on {}. Process exit code: {}.".format(
                    self._sql, self._master, ret_code
                )
            )

    def stop(self, context: Context):
        """Kill the ``spark-sql`` process if it was started and is still running."""
        if self._process and self._process.poll() is None:
            self.log.info("Killing the Spark-Sql job")
            self._process.kill()

    def _get_executable_path(self):
        """
        Return the path of the ``spark-sql`` executable to run.

        Uses the explicitly configured path when one was given, otherwise
        ``${SPARK_HOME}/bin/spark-sql`` after passing it through
        ``expand_env_var``.

        :raises AIFlowException: If no file exists at the resolved path.
        """
        if self._executable_path:
            spark_sql = self._executable_path
        else:
            spark_sql = expand_env_var('${SPARK_HOME}/bin/spark-sql')
        if not os.path.exists(spark_sql):
            raise AIFlowException(f'Cannot find spark-sql at {spark_sql}')
        return spark_sql

    def _build_spark_sql_command(self):
        """
        Assemble the argument list used to invoke ``spark-sql``.

        :return: The command as a list of arguments, executable first.
        :raises AIFlowException: If the ``spark-sql`` executable cannot be found.
        """
        spark_sql_cmd = [self._get_executable_path()]
        # Forward the configured master; without this flag spark-sql would
        # silently fall back to its own default master.
        if self._master:
            spark_sql_cmd += ["--master", self._master]
        if self._application_name:
            spark_sql_cmd += ["--name", self._application_name]
        if self._sql:
            sql = self._sql.strip()
            # A value that looks like a script path is run with -f, anything
            # else is treated as an inline statement and run with -e.
            if sql.endswith((".sql", ".hql")):
                spark_sql_cmd += ["-f", sql]
            else:
                spark_sql_cmd += ["-e", sql]
        self.log.debug("Spark-Sql cmd: %s", spark_sql_cmd)
        return spark_sql_cmd