# Source code for smv.datasetrepo
#
# This file is licensed under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
import pkgutil
import inspect
from error import SmvRuntimeError
from utils import for_name, smv_copy_array
"""Python implementations of IDataSetRepoPy4J and IDataSetRepoFactoryPy4J interfaces
"""
class DataSetRepoFactory(object):
    """Python implementation of the ``IDataSetRepoFactoryPy4J`` interface.

    Holds a reference to the application object and hands out new
    ``DataSetRepo`` instances built from it.
    """

    def __init__(self, smvApp):
        """Store the application object that every created repo will receive."""
        self.smvApp = smvApp

    def createRepo(self):
        """Return a new ``DataSetRepo`` bound to ``self.smvApp``.

        Any exception raised during construction has its traceback printed
        and is then re-raised unchanged.
        """
        try:
            return DataSetRepo(self.smvApp)
        except BaseException:
            traceback.print_exc()
            # Bare ``raise`` keeps the original traceback; ``raise e`` would
            # restart it from this frame on Python 2.
            raise

    class Java:
        # Py4J marker: Java interface(s) this Python object implements.
        implements = ['org.tresamigos.smv.IDataSetRepoFactoryPy4J']
class DataSetRepo(object):
    """Python implementation of the ``IDataSetRepoPy4J`` interface.

    On construction, every module belonging to one of the application's
    stages is evicted from ``sys.modules`` so that later imports pick up
    the most recent client source.
    """

    def __init__(self, smvApp):
        """Store ``smvApp`` and drop cached client modules."""
        self.smvApp = smvApp
        # Remove client modules from sys.modules to force reload of all client
        # code in the new transaction
        self._clear_sys_modules()

    def _clear_sys_modules(self):
        """Clear all client modules from sys.modules.

        A module counts as a client module when its name equals a stage name
        in ``self.smvApp.stages`` or starts with ``<stage name>.``.
        """
        # Iterate over a snapshot of the keys: popping entries while iterating
        # the live dict view raises RuntimeError on Python 3.
        for fqn in list(sys.modules):
            for stage_name in self.smvApp.stages:
                if fqn == stage_name or fqn.startswith(stage_name + "."):
                    sys.modules.pop(fqn, None)
                    break

    def loadDataSet(self, fqn):
        """Instantiate and return the dataset named ``fqn``.

        Implementation of IDataSetRepoPy4J loadDataSet, which loads the
        dataset from the most recent source. The class is resolved with
        ``for_name`` and constructed with ``self.smvApp``. Any exception has
        its traceback printed and is re-raised.
        """
        try:
            ds = for_name(fqn)(self.smvApp)
            # Python issue https://bugs.python.org/issue1218234
            # need to invalidate inspect.linecache to make dataset hash work
            srcfile = inspect.getsourcefile(ds.__class__)
            if srcfile:
                inspect.linecache.checkcache(srcfile)
            return ds
        except BaseException:
            traceback.print_exc()
            raise

    def dataSetsForStage(self, stageName):
        """Return the URNs of objects in ``stageName`` whose ``IsSmvPyDataSet`` is truthy.

        Any exception has its traceback printed and is re-raised.
        """
        try:
            return self._moduleUrnsForStage(stageName, lambda obj: obj.IsSmvPyDataSet)
        except BaseException:
            traceback.print_exc()
            raise

    def outputModsForStage(self, stageName):
        """Return the URNs of objects in ``stageName`` that are both modules and outputs.

        An object qualifies when both ``IsSmvPyModule`` and ``IsSmvPyOutput``
        are truthy. Any exception has its traceback printed and is re-raised,
        matching ``dataSetsForStage``.
        """
        try:
            return self._moduleUrnsForStage(
                stageName, lambda obj: obj.IsSmvPyModule and obj.IsSmvPyOutput)
        except BaseException:
            traceback.print_exc()
            raise

    def _moduleUrnsForStage(self, stageName, fn):
        """Collect the URNs of objects under ``stageName`` accepted by ``fn``.

        Args:
            stageName: dotted name of the stage package to scan.
            fn: predicate applied to every attribute of every non-package
                module found under the stage. An ``AttributeError`` raised
                while testing an object means "not a match".

        Returns:
            The result of ``smv_copy_array(self.smvApp.sc, *urns)``. The list
            of URNs is empty when the stage cannot be imported.
        """
        # `walk_packages` can generate AttributeError if the system has
        # Gtk modules, which are not designed to use with reflection or
        # introspection. Best action to take in this situation is probably
        # to simply suppress the error.
        def err(name):
            pass

        buf = []
        # import the stage and only walk the packages in the path of that stage, recursively
        try:
            stagemod = __import__(stageName)
        except ImportError:
            # may be a scala-only stage; other failures (e.g. a broken
            # stage __init__) are real errors and are allowed to propagate
            pass
        else:
            walker = pkgutil.walk_packages(
                stagemod.__path__, stagemod.__name__ + '.', onerror=err)
            for _loader, name, is_pkg in walker:
                # The additional "." is necessary to prevent false positive, e.g. stage_2.M1 matches stage
                if name.startswith(stageName + ".") and not is_pkg:
                    # __import__ returns the top-level package, so walk down
                    # the dotted path to reach the leaf module
                    pymod = __import__(name)
                    for c in name.split('.')[1:]:
                        pymod = getattr(pymod, c)

                    for n in dir(pymod):
                        obj = getattr(pymod, n)
                        try:
                            # Class should have an fqn which begins with the stageName.
                            # Each package will contain among other things all of
                            # the modules that were imported into it, and we need
                            # to exclude these (so that we only count each module once)
                            if fn(obj) and obj.fqn().startswith(name):
                                buf.append(obj.urn())
                        except AttributeError:
                            continue

        return smv_copy_array(self.smvApp.sc, *buf)

    def notFound(self, modUrn, msg):
        """Raise ``ValueError`` reporting that ``modUrn`` was not found in this repo."""
        raise ValueError("dataset [{0}] is not found in {1}: {2}".format(modUrn, self.__class__.__name__, msg))

    class Java:
        # Py4J marker: Java interface(s) this Python object implements.
        implements = ['org.tresamigos.smv.IDataSetRepoPy4J']