Source code for smv.smvapp

#
# This file is licensed under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SmvPy entry class and ``singleton``smvApp
This module provides the main SMV Python entry point ``SmvPy`` class and a singleton `smvApp`.
It is equivalent to ``SmvApp`` on Scala side
"""

from py4j.java_gateway import java_import, JavaObject

from pyspark import SparkContext
from pyspark.sql import HiveContext, DataFrame
from utils import for_name, smv_copy_array, check_socket
from error import SmvRuntimeError

import inspect
import pkgutil
import os
import re
import sys
import traceback

if sys.version >= '3':
    basestring = unicode = str
    long = int
    from io import StringIO
    from importlib import reload
else:
    from cStringIO import StringIO

[docs]class SmvApp(object): """The Python representation of SMV. Its singleton instance is created later in the containing module and is named smvApp Adds `java_imports` to the namespace in the JVM gateway in SparkContext (in pyspark). It also creates an instance of SmvPyClient. """ # Singleton instance of SmvApp _instance = None @classmethod
[docs] def getInstance(cls): if cls._instance is None: raise SmvRuntimeError("An instance of SmvApp has not been created") else: return cls._instance
@classmethod
[docs] def createInstance(cls, arglist, _sc = None, _sqlContext = None): """Create singleton instance. Also returns the instance. """ cls._instance = cls(arglist, _sc, _sqlContext) return cls._instance
def __init__(self, arglist, _sc = None, _sqlContext = None): sc = SparkContext() if _sc is None else _sc sqlContext = HiveContext(sc) if _sqlContext is None else _sqlContext self.prepend_source("src/main/python") sc.setLogLevel("ERROR") self.sqlContext = sqlContext self.sc = sc self._jvm = sc._jvm from py4j.java_gateway import java_import java_import(self._jvm, "org.tresamigos.smv.ColumnHelper") java_import(self._jvm, "org.tresamigos.smv.SmvDFHelper") java_import(self._jvm, "org.tresamigos.smv.dqm.*") java_import(self._jvm, "org.tresamigos.smv.python.SmvPythonHelper") self.j_smvPyClient = self.create_smv_pyclient(arglist) # shortcut is meant for internal use only self.j_smvApp = self.j_smvPyClient.j_smvApp() # issue #429 set application name from smv config sc._conf.setAppName(self.appName()) # user may choose a port for the callback server gw = sc._gateway cbsp = self.j_smvPyClient.callbackServerPort() cbs_port = cbsp.get() if cbsp.isDefined() else gw._python_proxy_port # check wither the port is in-use or not. Try 10 times, if all fail, error out check_counter = 0 while(not check_socket(cbs_port) and check_counter < 10): cbs_port += 1 check_counter += 1 if (not check_socket(cbs_port)): raise SmvRuntimeError("Start Python callback server failed. Port {0}-{1} are all in use".format(cbs_port - check_counter, cbs_port)) # this was a workaround for py4j 0.8.2.1, shipped with spark # 1.5.x, to prevent the callback server from hanging the # python, and hence the java, process from pyspark.streaming.context import _daemonize_callback_server _daemonize_callback_server() if "_callback_server" not in gw.__dict__ or gw._callback_server is None: print("Starting Py4j callback server on port {0}".format(cbs_port)) gw._shutdown_callback_server() # in case another has already started gw._start_callback_server(cbs_port) gw._python_proxy_port = gw._callback_server.port # get the GatewayServer object in JVM by ID jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) # update the port of CallbackClient with real port gw.jvm.SmvPythonHelper.updatePythonGatewayPort(jgws, gw._python_proxy_port) self.j_smvPyClient.registerRepoFactory('Python', DataSetRepoFactory(self))
[docs] def appName(self): return self.j_smvApp.smvConfig().appName()
[docs] def create_smv_pyclient(self, arglist): ''' return a smvPyClient instance ''' # convert python arglist to java String array java_args = smv_copy_array(self.sc, *arglist) return self._jvm.org.tresamigos.smv.python.SmvPyClientFactory.init(java_args, self.sqlContext._ssql_ctx)
[docs] def get_graph_json(self): """Generate a json string representing the dependency graph. TODO: need to add a stageName parameter to limit it to a single stage. """ return self.j_smvApp.generateAllGraphJSON()
[docs] def runModule(self, urn): """Runs either a Scala or a Python SmvModule by its Fully Qualified Name(fqn) """ jdf = self.j_smvPyClient.runModule(urn) return DataFrame(jdf, self.sqlContext)
[docs] def runModuleByName(self, name): jdf = self.j_smvApp.runModuleByName(name) return DataFrame(jdf, self.sqlContext)
[docs] def urn2fqn(self, urnOrFqn): """Extracts the SMV module FQN portion from its URN; if it's already an FQN return it unchanged""" return self.j_smvPyClient.urn2fqn(urnOrFqn)
[docs] def outputDir(self): return self.j_smvPyClient.outputDir()
[docs] def scalaOption(self, val): """Returns a Scala Option containing the value""" return self._jvm.scala.Option.apply(val)
[docs] def createDF(self, schema, data = None): return DataFrame(self.j_smvPyClient.dfFrom(schema, data), self.sqlContext)
def _mkCsvAttr(self, delimiter=',', quotechar='"', hasHeader=False): """Factory method for creating instances of Scala case class CsvAttributes""" return self._jvm.org.tresamigos.smv.CsvAttributes(delimiter, quotechar, hasHeader)
[docs] def defaultCsvWithHeader(self): return self._mkCsvAttr(hasHeader=True)
[docs] def defaultTsv(self): return self._mkCsvAttr(delimiter='\t')
[docs] def defaultTsvWithHeader(self): return self._mkCsvAttr(delimier='\t', hasHeader=True)
[docs] def prepend_source(self,source_dir): # Source must be added to front of path to make sure it is found first codePath = os.path.abspath(source_dir) sys.path.insert(1, codePath)
[docs] def run(self): self.j_smvApp.run()
[docs]class DataSetRepoFactory(object): def __init__(self, smvApp): self.smvApp = smvApp
[docs] def createRepo(self): try: return DataSetRepo(self.smvApp) except BaseException as e: traceback.print_exc() raise e
[docs] class Java: implements = ['org.tresamigos.smv.IDataSetRepoFactoryPy4J']
[docs]class DataSetRepo(object): def __init__(self, smvApp): self.smvApp = smvApp # Implementation of IDataSetRepoPy4J loadDataSet, which loads the dataset # from the most recent source
[docs] def loadDataSet(self, fqn): lastdot = fqn.rfind('.') try: if sys.modules.has_key(fqn[:lastdot]): # reload the module if it has already been imported return self._reload(fqn) else: # otherwise import the module return self._load(fqn) except BaseException as e: traceback.print_exc() raise e
# Import the module (Python module, not SMV module) containing the dataset # and return the dataset def _load(self, fqn): return for_name(fqn)(self.smvApp) # Reload the module containing the dataset from the most recent source # and invalidate the linecache def _reload(self, fqn): lastdot = fqn.rfind('.') pmod = reload(sys.modules[fqn[:lastdot]]) klass = getattr(pmod, fqn[lastdot+1:]) ds = klass(self.smvApp) # Python issue https://bugs.python.org/issue1218234 # need to invalidate inspect.linecache to make dataset hash work srcfile = inspect.getsourcefile(ds.__class__) if srcfile: inspect.linecache.checkcache(srcfile) return ds
[docs] def dataSetsForStage(self, stageName): try: return self._moduleUrnsForStage(stageName, lambda obj: obj.IsSmvPyDataSet) except BaseException as e: traceback.print_exc() raise e
[docs] def outputModsForStage(self, stageName): return self.moduleUrnsForStage(stageName, lambda obj: obj.IsSmvPyModule and obj.IsSmvPyOutput)
def _moduleUrnsForStage(self, stageName, fn): # `walk_packages` can generate AttributeError if the system has # Gtk modules, which are not designed to use with reflection or # introspection. Best action to take in this situation is probably # to simply suppress the error. def err(name): pass # print("Error importing module %s" % name) # t, v, tb = sys.exc_info() # print("type is {0}, value is {1}".format(t, v)) buf = [] # import the stage and only walk the packages in the path of that stage, recursively try: stagemod = __import__(stageName) except: # may be a scala-only stage pass else: for loader, name, is_pkg in pkgutil.walk_packages(stagemod.__path__, stagemod.__name__ + '.' , onerror=err): # The additional "." is necessary to prevent false positive, e.g. stage_2.M1 matches stage if name.startswith(stageName + ".") and not is_pkg: pymod = __import__(name) for c in name.split('.')[1:]: pymod = getattr(pymod, c) for n in dir(pymod): obj = getattr(pymod, n) try: # Class should have an fqn which begins with the stageName. # Each package will contain among other things all of # the modules that were imported into it, and we need # to exclude these (so that we only count each module once) if fn(obj) and obj.fqn().startswith(name): buf.append(obj.urn()) except AttributeError: continue return smv_copy_array(self.smvApp.sc, *buf)
[docs] def notFound(self, modUrn, msg): raise ValueError("dataset [{0}] is not found in {1}: {2}".format(modUrn, self.__class__.__name__, msg))
[docs] class Java: implements = ['org.tresamigos.smv.IDataSetRepoPy4J']