Source code for pyspark_utils.utils

from typing import Dict, List, Optional

import pandas as pd
import pyspark
import pyspark.sql
import pyspark.sql.functions as F
from pyspark.sql.window import Window


def get_spark_session(app_name: str) -> pyspark.sql.SparkSession:
    """Get or create a SparkSession with the given application name.

    Args:
        app_name (str): Name of the application

    Returns:
        pyspark.sql.SparkSession: A Spark session named `app_name`, with Hive
        support enabled and automatic broadcast joins disabled
    """
    session = (
        pyspark.sql.SparkSession.builder.appName(app_name)
        .enableHiveSupport()
        .config("spark.sql.autoBroadcastJoinThreshold", -1)
        .getOrCreate()
    )
    return session
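
A minimal usage sketch, assuming a working Spark installation; the application name and the toy dataframe are illustrative:

    # Hypothetical usage: obtain a session and build a small dataframe
    spark = get_spark_session("example_app")  # "example_app" is an example name
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "letter"])
    df.show()
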
def assert_cols_in_df(
    df: pyspark.sql.DataFrame, *columns: str, df_name: Optional[str] = ""
) -> None:
    """Asserts that all specified columns are present in the specified dataframe.
    If not, displays an informative message.

    Args:
        df (pyspark.sql.DataFrame): pyspark dataframe
        *columns (str): names of the columns expected in `df`
        df_name (Optional[str], optional): name of the dataframe, used in the
            error message. Defaults to "".
    """
    assert set(columns).issubset(
        df.columns
    ), f"Columns {' & '.join(set(columns) - set(df.columns))} missing from dataframe {df_name}"
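
A usage sketch with hypothetical column names; the last call is commented out because it would raise:

    spark = get_spark_session("example_app")
    df = spark.createDataFrame([(1, "a")], ["id", "letter"])
    assert_cols_in_df(df, "id", "letter", df_name="df")  # passes silently
    # assert_cols_in_df(df, "missing", df_name="df")  # AssertionError listing "missing"
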
def assert_df_close(df1: pyspark.sql.DataFrame, df2: pyspark.sql.DataFrame, **kwargs) -> None:
    """Asserts that two dataframes are (almost) equal, even if the order of the
    rows or the columns is different.

    Args:
        df1 (pyspark.sql.DataFrame): first dataframe to compare
        df2 (pyspark.sql.DataFrame): second dataframe to compare
        kwargs: Any keyword argument of `pandas.testing.assert_frame_equal`
    """
    df1_pd: pd.DataFrame = df1.toPandas()
    df2_pd: pd.DataFrame = df2.toPandas()
    cols1 = sorted(df1_pd.columns)
    cols2 = sorted(df2_pd.columns)
    pd.testing.assert_frame_equal(
        df1_pd[cols1].sort_values(by=cols1).reset_index(drop=True),
        df2_pd[cols2].sort_values(by=cols2).reset_index(drop=True),
        **kwargs,
    )
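
A usage sketch with illustrative values, showing that row order, column order, and small numerical differences are tolerated; `check_exact` and `atol` are forwarded to `pandas.testing.assert_frame_equal` (requires pandas >= 1.1):

    spark = get_spark_session("example_app")
    df1 = spark.createDataFrame([(1, 0.3), (2, 0.7)], ["id", "score"])
    df2 = spark.createDataFrame([(0.7000001, 2), (0.3, 1)], ["score", "id"])
    assert_df_close(df1, df2, check_exact=False, atol=1e-5)  # passes
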
def with_columns(
    df: pyspark.sql.DataFrame, col_func_mapping: Dict[str, pyspark.sql.Column]
) -> pyspark.sql.DataFrame:
    """Apply multiple 'withColumn' calls on a dataframe in a single command.

    Args:
        df (pyspark.sql.DataFrame): pyspark dataframe
        col_func_mapping (Dict[str, pyspark.sql.Column]): dict mapping each column
            name to the column expression that defines it

    Returns:
        pyspark.sql.DataFrame: A pyspark dataframe identical to `df` but with the
        mapped columns added (or replaced, if they already exist).
    """
    for col_name, col_func in col_func_mapping.items():
        df = df.withColumn(col_name, col_func)
    return df
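
A usage sketch with hypothetical column names; entries are applied in dict order, so later entries can reference columns defined by earlier ones:

    spark = get_spark_session("example_app")
    df = spark.createDataFrame([(1, 2.0), (3, 4.0)], ["x", "y"])
    df = with_columns(
        df,
        {
            "total": F.col("x") + F.col("y"),  # adds a new column
            "x": F.col("x") * 2,               # replaces an existing column
        },
    )
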
def keep_first_rows(
    df: pyspark.sql.DataFrame, partition_cols: List[str], order_cols: List[str]
) -> pyspark.sql.DataFrame:
    """Keep the first row of each group defined by `partition_cols`, ordered by
    `order_cols`.

    Args:
        df (pyspark.sql.DataFrame): pyspark dataframe
        partition_cols (List[str]): names of the columns defining the groups
        order_cols (List[str]): names of the columns defining the order within
            each group

    Returns:
        pyspark.sql.DataFrame: A dataframe with exactly one row per group: the
        first row of that group according to `order_cols`.
    """
    return (
        df.withColumn(
            "rank",
            F.rank().over(
                Window.partitionBy(*partition_cols).orderBy(*order_cols + [F.rand(seed=1)])
            ),
            # We add a seeded random column so that ties are broken arbitrarily
            # but reproducibly, guaranteeing a single row with rank 1 per group
        )
        .filter(F.col("rank") == 1)
        .drop("rank")
    )
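
A usage sketch with hypothetical user/event columns, keeping each user's earliest event (ISO date strings sort chronologically):

    spark = get_spark_session("example_app")
    events = spark.createDataFrame(
        [("u1", "2023-01-02"), ("u1", "2023-01-01"), ("u2", "2023-01-05")],
        ["user_id", "event_date"],
    )
    first_events = keep_first_rows(events, ["user_id"], ["event_date"])
    # -> one row per user_id: ("u1", "2023-01-01") and ("u2", "2023-01-05")
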