Alternate title: k-Nearest Neighbours (kNN) in PySpark

You can follow the story of what I wanted to do and how I did it. Or jump to the solution.

## Situation

• PySpark MLlib (both the DataFrame-based and RDD-based APIs) does not provide a kNN algorithm.
• Several SO questions (1, 2, 3) did not help.
• I did not get a chance to try the available open-source code samples (1, 2).
• There were two stable-ish solutions available:
1. Use Spotify’s library called Annoy.
2. Use Scikit-learn’s implementation of kNN.
• Both methods run on a single node, so the benefit of Spark’s distributed processing goes out the window.
• That’s a deal breaker because I have a large data set (20+ million records) with long vectors.
• I picked Annoy because I found it first. At the end, I discuss why Scikit-learn could be more performant.

The task is to parallelise the Annoy code across multiple nodes of the Spark cluster.

## Solution

• A hint about the solution is present in this SO Answer.
• For both Annoy and Scikit, the approach is as follows:
1. Build the index or fit the model on a single node. Nothing is distributed here.
2. Broadcast the index (or model) across the cluster to find the nearest neighbours of a given vector.

### Building the index

• I first tried to use spark-annoy. It is in Scala. The benefit of this library was that we could build the index in a distributed manner. Unfortunately, I could not figure it out.
• The fallback was the iterative approach: build the index on a single node.
• The following is the method to build the index:
• The for loop makes it a long-running process if the data is large. Sadly, it is unavoidable.
• Note that vectors is a pd.Series. I used pandas for the convenience of its indexes, but a list or any other iterable works just as well. Whatever the type, the vectors must be iterable on a single node. That means either of the following:
1. Run .collect() on the Spark DataFrame;
2. Turn the spark DataFrame into a pandas DataFrame.
• Either option brings all the data to a single node, so it can lead to an OOM error.
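
A minimal sketch of such a build step, assuming the standard Annoy calls (AnnoyIndex, add_item, build, save). The name build_save_annoy_index comes from later in this post. The parameters, the angular metric, the tree count, and the use of the position in the iterable as the Annoy item id are assumptions for illustration.

```python
def build_save_annoy_index(vectors, dim, index_file, metric="angular", n_trees=10):
    """Build an Annoy index from an iterable of vectors and save it to a file.

    Assumed signature: `metric` and `n_trees` are illustrative defaults.
    """
    # Imported inside the function so the module loads even where Annoy is absent.
    from annoy import AnnoyIndex

    index = AnnoyIndex(dim, metric)

    # The single-node loop: one add_item call per vector.
    # Assumption: the position in `vectors` doubles as the Annoy item id.
    for item_id, vector in enumerate(vectors):
        index.add_item(item_id, vector)

    index.build(n_trees)    # more trees: better recall, larger index
    index.save(index_file)  # writes the memory-mappable index file
    return index_file
```

Keeping a separate mapping from these integer ids back to the real item ids is needed later, because Annoy only stores non-negative integer ids.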

### Finding Nearest Neighbours

• We can parallelise this step.
• We have to broadcast the Annoy index across all the nodes of the Spark cluster.
• The Annoy indexes are memory mapped.

It also creates large read-only file-based data structures that are mmapped into memory so that many processes may share the same data.

• Broadcasting the index object directly with sc.broadcast(t) fails. This SO answer discusses the issue.
• The solution: write the index to a file and send the file to all the workers to load.
• Use sc.addFile() to send the file to the workers.
• Use SparkFiles.get() to get the file path and load the index on the worker node.
• Here is the method to load the Annoy index in the worker nodes:
• I call the method below to get the nearest neighbours of a set of index ids:
• item_batch is the list of Annoy index ids.
• The function returns a list of (annoy_item_index, [(rank_1, sim_item_1), ..., (rank_n, sim_item_n)]).
• For example, here is one such item from the list:
(248, [[0, 248], [9, 284764], [3, 86148], [6, 265812], [7, 508155], [2, 48388], [10, 58786], [1, 154653], [5, 364419], [4, 4444], [8, 89955]])
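
A minimal sketch of a worker-side lookup along these lines, assuming the index file was shipped with sc.addFile(). The name find_neighbours and the parameters item_batch, index_file, top_n, and dim come from this post. The metric argument, the basename handling, and the exact output layout are assumptions.

```python
import os


def find_neighbours(item_batch, index_file, top_n, dim, metric="angular"):
    """Return (item_id, [(rank, neighbour_id), ...]) for each id in item_batch."""
    # Imports live inside the function so they are resolved on the worker.
    from annoy import AnnoyIndex
    from pyspark import SparkFiles

    index = AnnoyIndex(dim, metric)  # dim and metric must match the saved index
    # SparkFiles.get() takes the file name and returns its local path on this node.
    index.load(SparkFiles.get(os.path.basename(index_file)))

    results = []
    for item_id in item_batch:
        neighbours = index.get_nns_by_item(item_id, top_n)
        # Pair each neighbour with its rank; rank 0 is the closest match.
        results.append((item_id, list(enumerate(neighbours))))
    return results
```

Loading the index once per call, rather than once per item, matters here: the function is meant to run once per partition, so each partition pays the load cost a single time.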

• As a sanity check, the most similar item to a query item should be the item itself. In the above example, the query index 248 has rank 0 among its top similar items.
• To get the nearest neighbours by vector instead, pass the vectors in item_batch and use the get_nns_by_vector method.
• I keep my find_neighbours method generic for the following parameters:
• Annoy index file name: I can have any name based on my use case.
• Top n items: number of top items I want to retrieve.
• Dim: The dimension of the vector can vary depending on the various ML techniques (LDA, DL, etc.)
• We have only written the method to find the nearest neighbours. How do we call it in a distributed manner? This SO answer answers that too.
• The answer is mapPartitions. This method applies the passed function to each partition of the RDD. See the answers to this SO question for more detail.
• I will pass find_neighbours to mapPartitions, and it will return an RDD with the nearest neighbours list.
• But my find_neighbours implementation takes four parameters, and mapPartitions offers no way to pass extra arguments to the function it calls.

• I use Python’s built-in partial() function from the functools module.

The partial() is used for partial function application which “freezes” some portion of a function’s arguments and/or keywords resulting in a new object with a simplified signature.

• Here is how my final function looks:
• The build_save_annoy_index() method builds the index, saves it to a file, and returns the file path.
• Finally, we see the use of sc.addFile(index_file).
• find_neighbours_() is the partial function, with index_file, top_n, and dim frozen. It now expects a single argument, the iterator over one partition’s elements, which is exactly what mapPartitions() passes.
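
The freezing step can be tried without a cluster. The sketch below uses a stand-in find_neighbours that only echoes its arguments, so it runs without Spark or Annoy. The file name and the values of top_n and dim are hypothetical, and the Spark calls in the trailing comment are an assumed outline, not the original code.

```python
from functools import partial


def find_neighbours(item_batch, index_file, top_n, dim):
    # Stand-in body: echoes its arguments instead of querying an Annoy index.
    return [(item_id, index_file, top_n, dim) for item_id in item_batch]


# Freeze everything except the first argument (the partition iterator).
find_neighbours_ = partial(find_neighbours, index_file="items.ann", top_n=10, dim=64)

# mapPartitions calls its function with one argument: an iterator over a partition.
partition = iter([0, 1, 2])
print(find_neighbours_(partition))

# With a live SparkContext `sc`, the distributed call would look roughly like:
#   sc.addFile(index_file)
#   similar_items = sc.parallelize(item_ids).mapPartitions(find_neighbours_).collect()
```

Freezing the arguments by keyword keeps the first positional slot free for the iterator that mapPartitions supplies.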

### Saving Results

• I take the similar_items list and convert it into a pandas DataFrame.
• Map ALL the Annoy index ids to the actual item ids. That includes every index id in the top-n similar items lists.
• Convert the pandas DataFrame to a PySpark DataFrame.
• Save the PySpark DataFrame to a Delta table.
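
A small sketch of the id-mapping step in plain Python, assuming a lookup from Annoy index ids to real item ids is available. The helper to_rows, the column names, and the sample ids are hypothetical. The resulting rows are in a shape that pandas.DataFrame accepts.

```python
def to_rows(similar_items, id_lookup):
    """Flatten (query, [(rank, neighbour), ...]) pairs into one row per neighbour."""
    rows = []
    for query_idx, neighbours in similar_items:
        for rank, neighbour_idx in neighbours:
            rows.append({
                # Both the query id and the neighbour id go through the lookup.
                "item_id": id_lookup[query_idx],
                "rank": rank,
                "similar_item_id": id_lookup[neighbour_idx],
            })
    return rows


# Hypothetical sample data in the shape find_neighbours returns.
similar_items = [(0, [(0, 0), (1, 2)]), (1, [(0, 1), (1, 0)])]
id_lookup = {0: "sku-a", 1: "sku-b", 2: "sku-c"}

rows = to_rows(similar_items, id_lookup)
print(rows[1])

# From here the rows can go into pandas.DataFrame(rows), then into a
# PySpark DataFrame, and finally out to a Delta table.
```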

## Result

• I was able to parallelise the kNN search based on Annoy using mapPartitions.
• On ~500k records, the run time was down from 8 minutes to 2 minutes.
• On ~10 million records (with an index built from ~500k records), the run time was ~1 hour.
• On ~10 million records (with an index built from ~10 million records), I got an OOM error. 🥲

## What’s Next

• Find the reason it is going OOM.
• Converting the pandas DataFrame to a PySpark DataFrame is expensive. I want to explore whether I can write directly from the pandas DataFrame to the Delta table. Ref: 1, 2.
• Since Scikit-learn has vectorised training and inference, its kNN would likely be faster. This post shows how to do it. I would probably replace the map with mapPartitions.