# Source code for smv.smvschema
#
# This file is licensed under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import pyspark.sql.types as sql_types
from smv.smvapp import SmvApp
class SmvSchema(object):
    """The Python representation of SmvSchema scala class.

    Most of the work is still being done on the scala side and this is just a pass through.
    """

    def __init__(self, j_smv_schema):
        """Wrap a JVM-side schema object and cache its Spark ``StructType``.

        Args:
            j_smv_schema: handle to the scala SmvSchema instance; every
                method of this class delegates to it.
        """
        self.j_smv_schema = j_smv_schema
        self.spark_schema = self._toStructType()

    @staticmethod
    def discover(csv_path, csvAttributes, n=100000):
        """Discover schema from CSV file with given csv attributes

        Args:
            csv_path: path of the CSV file, passed through to the SmvApp.
            csvAttributes: csv attributes object, passed through unchanged.
            n: passed through to ``discoverSchemaAsSmvSchema`` (presumably
                the number of records to inspect -- confirm on the scala side).

        Returns:
            SmvSchema wrapping the discovered schema.
        """
        smvApp = SmvApp.getInstance()
        j_smv_schema = smvApp.discoverSchemaAsSmvSchema(csv_path, csvAttributes, n)
        return SmvSchema(j_smv_schema)

    @staticmethod
    def fromFile(schema_file):
        """Create an SmvSchema from a schema file.

        Delegates to ``smvSchemaObj.fromFile`` using the running app's
        SparkContext.

        Args:
            schema_file: path of the schema file to read.

        Returns:
            SmvSchema wrapping the loaded schema.
        """
        smvApp = SmvApp.getInstance()
        j_smv_schema = smvApp.smvSchemaObj.fromFile(smvApp.j_smvApp.sc(), schema_file)
        return SmvSchema(j_smv_schema)

    @staticmethod
    def fromString(schema_str):
        """Create an SmvSchema from a schema string.

        Args:
            schema_str: schema definition, passed to ``smvSchemaObj.fromString``.

        Returns:
            SmvSchema wrapping the parsed schema.
        """
        smvApp = SmvApp.getInstance()
        j_smv_schema = smvApp.smvSchemaObj.fromString(schema_str)
        return SmvSchema(j_smv_schema)

    def toValue(self, i, str_val):
        """convert the string value to native value based on type defined in schema.

        For example, if first column was of type Int, then call of `toValue(0, "55")`
        would return integer 55.

        Args:
            i: zero-based column index into the schema.
            str_val: string representation of the value.

        Returns:
            The converted value; a ``datetime.date`` for Date columns, or
            ``None`` when the scala side yields no value.
        """
        val = self.j_smv_schema.toValue(i, str_val)
        # A null coming back from the JVM has no getYear()/getMonth() to call,
        # so return it as-is instead of failing in the date conversion below.
        if val is None:
            return None
        j_type = self.spark_schema.fields[i].dataType.typeName()
        if j_type == sql_types.DateType.typeName():
            # The JVM date object (presumably java.sql.Date) is turned into a
            # Python date: its year is offset by 1900 and its month by 1.
            val = datetime.date(1900 + val.getYear(), val.getMonth() + 1, val.getDate())
        return val

    def saveToLocalFile(self, schema_file):
        """Save schema to local (driver) file"""
        self.j_smv_schema.saveToLocalFile(schema_file)

    def _scala_to_python_field_type(self, scala_field_type):
        """create a python FieldType from the scala field type

        Args:
            scala_field_type: scala ``StructField`` handle.

        Returns:
            ``pyspark.sql.types.StructField`` with the same name and data type.
        """
        col_name = str(scala_field_type.name())
        scala_data_type = scala_field_type.dataType()
        col_type_name = str(scala_data_type)
        try:
            # map string "IntegerType" to actual class IntegerType
            col_type_class = getattr(sql_types, col_type_name)
        except AttributeError:
            # Parametrized or nested types print as e.g. "DecimalType(10,2)",
            # which is not an attribute of pyspark.sql.types. Rebuild those
            # from the type's JSON form instead. nullable/metadata mirror the
            # defaults used by the plain StructField(...) path below.
            import json
            return sql_types.StructField.fromJson({
                "name": col_name,
                "type": json.loads(scala_data_type.json()),
                "nullable": True,
                "metadata": {},
            })
        return sql_types.StructField(col_name, col_type_class())

    def _toStructType(self):
        """return equivalent Spark schema (StructType) from this smv schema"""
        # ss is the raw scala spark schema (Scala StructType). This has no
        # iterator defined on the python side, so fields are fetched by index,
        # using "apply" to get the nth StructField item in StructType.
        ss = self.j_smv_schema.toStructType()
        return sql_types.StructType(
            [self._scala_to_python_field_type(ss.apply(i)) for i in range(ss.length())]
        )