Source code for smv.datasetmgr

#
# This file is licensed under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DataSetMgr entry class
This module provides the python entry point to DataSetMgr on scala side
"""
import smv

from smv.utils import smv_copy_array, scala_seq_to_list, list_distinct, infer_full_name_from_part
from smv.datasetresolver import DataSetResolver

[docs]class DataSetMgr(object): """The Python representation of DataSetMgr. """ def __init__(self, _jvm, smvconfig): self._jvm = _jvm self.smvconfig = smvconfig self.dsRepoFactories = [] from py4j.java_gateway import java_import java_import(self._jvm, "org.tresamigos.smv.python.SmvPythonHelper") java_import(self._jvm, "org.tresamigos.smv.DataSetRepoFactoryPython") self.helper = self._jvm.SmvPythonHelper
[docs] def stages(self): return self.smvconfig.stage_names()
[docs] def tx(self): """Create a TXContext for multiple places, avoid the long TXContext line """ return TXContext(self._jvm, self.dsRepoFactories, self.stages())
[docs] def load(self, *fqns): """Load SmvGenericModules for specified FQNs Args: *fqns (str): list of FQNs as strings Returns: list(SmvGenericModules): list of Scala SmvGenericModules (j_ds) """ with self.tx() as tx: return tx.load(fqns)
[docs] def inferDS(self, *partial_names): """Return DSs from a list of partial names Args: *partial_names (str): list of partial names Returns: list(SmvGenericModules): list of SmvGenericModules """ with self.tx() as tx: return tx.inferDS(partial_names)
[docs] def inferFqn(self, partial_name): """Return FQN string from partial name """ with self.tx() as tx: return tx._inferFqn([partial_name])[0]
[docs] def register(self, repo_factory): """Register python repo factory """ self.dsRepoFactories.append(repo_factory)
[docs] def allDataSets(self): """Return all the SmvGenericModules in the app """ with self.tx() as tx: return tx.allDataSets()
[docs] def modulesToRun(self, modPartialNames, stageNames, allMods): """Return a modules need to run Combine specified modules, (-m), stages, (-s) and if (--run-app) specified, all output modules """ with self.tx() as tx: named_mods = tx.inferDS(modPartialNames) stage_mods = tx.outputModulesForStage(stageNames) app_mods = tx.allOutputModules() if allMods else [] res = [] res.extend(named_mods) res.extend(stage_mods) res.extend(app_mods) # Need to perserve the ordering return list_distinct(res)
[docs]class TXContext(object): """Create a TX context for "with tx() as tx" syntax """ def __init__(self, _jvm, resourceFactories, stages): self._jvm = _jvm self.resourceFactories = resourceFactories self.stages = stages def __enter__(self): return TX(self._jvm, self.resourceFactories, self.stages) def __exit__(self, type, value, traceback): pass
[docs]class TX(object): """Abstraction of the transaction boundary for loading SmvGenericModules. A TX object * will instantiate a set of repos when itself instantiated and will * reuse the same repos for all queries. This means that each new TX object will * reload the SmvGenericModules from source **once** during its lifetime. NOTE: Once a new TX is created, the well-formedness of the SmvGenericModules provided by the previous TX is not guaranteed. Particularly it may become impossible to run modules from the previous TX. """ def __init__(self, _jvm, resourceFactories, stages): self.repos = [rf.createRepo() for rf in resourceFactories] self.stages = stages self.resolver = DataSetResolver(self.repos[0]) self.log = smv.logger
[docs] def load(self, fqns): return self.resolver.loadDataSet(fqns)
[docs] def inferDS(self, partial_names): return self.load(self._inferFqn(partial_names))
[docs] def allDataSets(self): return self.load(self._allFqns())
[docs] def allOutputModules(self): return self._filterOutput(self.allDataSets())
[docs] def outputModulesForStage(self, stageNames): return self._filterOutput(self._dsForStage(stageNames))
def _dsForStage(self, stageNames): return self.load(self._fqnsForStage(stageNames)) def _fqnsForStage(self, stageNames): return [u for repo in self.repos for s in stageNames for u in repo.dataSetsForStage(s) ] def _allFqns(self): if (len(self.stages) == 0): log.warn("No stage names configured. Unable to discover any modules.") return self._fqnsForStage(self.stages) def _inferFqn(self, partial_names): def fqn_str(pn): return infer_full_name_from_part( self._allFqns(), pn ) return [fqn_str(pn) for pn in partial_names] def _filterOutput(self, dss): return [ds for ds in dss if ds.isSmvOutput()]