In Spark SQL, how can I optimise non-equi-joins?

  • I have two dataframes that I need to join using a non-equi-join (also known as an inequality join) with two join predicates. The first dataframe is a histogram: DataFrame[bin: bigint, lower_bound: double, upper_bound: double]

    The second dataframe is a set of observations: DataFrame[id: bigint, observation: double]. I need to figure out which bin of my histogram each observation belongs in, as follows:

    observations_df.join(histogram_df, 
    (
    (observations_df.observation >= histogram_df.lower_bound) &
    (observations_df.observation < histogram_df.upper_bound)
    )
    )

    The join is very slow, and I'm looking for recommendations on how to speed it up. The sample code below reproduces the issue. When the number of rows in histogram_df gets sufficiently high (say, number_of_bins = 500000), the query becomes extremely slow, and I'm confident that is because I'm doing a non-equi-join. As this blog suggests, non-equi-joins are not advised in Spark. If you run this code, play with the value of number_of_bins, starting with something small and gradually increasing it until the slow performance becomes obvious.

    from pyspark.sql import SparkSession, Window
    from pyspark.sql.functions import col, lead, lit, rand
    spark = SparkSession \
    .builder \
    .getOrCreate()

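    number_of_bins = 500000

    bin_width = 1.0 / number_of_bins
    # No partitionBy: the window is global, so lead() sees every bin in order
    window = Window.orderBy('bin')
    # One row per bin; upper_bound is the next bin's lower_bound (1.0 for the last bin)
    histogram_df = spark.range(0, number_of_bins)\
    .withColumnRenamed('id', 'bin')\
    .withColumn('lower_bound', 0 + lit(bin_width) * col('bin'))\
    .select('bin', 'lower_bound', lead('lower_bound', 1, 1.0).over(window).alias('upper_bound'))
    # Uniform random observations in [0, 1)
    observations_df = spark.range(0, 100000).withColumn('observation', rand())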
    observations_df.join(histogram_df,
    (
    (observations_df.observation >= histogram_df.lower_bound) &
    (observations_df.observation < histogram_df.upper_bound)
    )
    ).groupBy('bin').count().head(15)
  • Why did you pick the SQL 2012 forum? That version doesn't deal with PySpark / dataframes.
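
  • On the question itself: a likely cause of the slowdown is that a join with no equality condition gives Spark no key to hash or sort on, so it tends to test every observation against every bin. In the sample that is 100000 × 500000 = 5 × 10^10 comparisons. Because the bins are contiguous and sorted, the bin for an observation can instead be found with a binary search over the lower bounds. The sketch below shows only that lookup in plain Python, without Spark; the names build_lower_bounds, find_bin and count_per_bin are hypothetical helpers, not part of any library.

```python
from bisect import bisect_right
import random


def build_lower_bounds(number_of_bins):
    """Lower bounds of equal-width bins on [0, 1), mirroring the question's histogram."""
    bin_width = 1.0 / number_of_bins
    return [bin_width * b for b in range(number_of_bins)]


def find_bin(lower_bounds, observation):
    """Return the bin index with lower_bound <= observation < upper_bound.

    Assumes lower_bounds is sorted ascending and the bins are contiguous,
    i.e. each bin's upper bound is the next bin's lower bound.
    """
    # bisect_right counts the lower bounds that are <= observation;
    # the last of those belongs to the matching bin.
    return bisect_right(lower_bounds, observation) - 1


def count_per_bin(lower_bounds, observations):
    """Count observations per bin: the equivalent of groupBy('bin').count()."""
    counts = {}
    for obs in observations:
        b = find_bin(lower_bounds, obs)
        counts[b] = counts.get(b, 0) + 1
    return counts


# Small demonstration with the bin count from the question.
lower_bounds = build_lower_bounds(500000)
observations = [random.random() for _ in range(1000)]
counts = count_per_bin(lower_bounds, observations)
print(len(counts), "non-empty bins")
```

    Each lookup costs O(log n) instead of a scan over all bins. In Spark, one way to apply the same idea, assuming the histogram is small enough to fit in memory on each executor, would be to broadcast the sorted lower bounds and do the lookup per observation. For equal-width bins like the ones in the sample, the bin index could also be computed directly as floor(observation / bin_width) and then joined on bin by equality, though floating-point rounding at bin edges may then differ slightly from the range comparison.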
