Alternate title: k-Nearest Neighbours (kNN) in PySpark
You can follow the story of what I wanted to do and how I did it. Or jump to the solution.
- PySpark MLlib (both the DataFrame-based and the RDD-based API) does not support the kNN algorithm.
- The multiple SO questions (1, 2, 3) did not help.
- I did not get a chance to try the available open-source code samples (1, 2).
- There were two stable-ish solutions available: Annoy and Scikit-learn.
- Both methods only use a single node. So, the benefit of Spark’s distributed processing goes out of the window.
- That’s a deal breaker because I have a large data set (20+ million records) with long vectors.
- I picked Annoy because I found it first. I discuss at the end why Scikit could be more performant.
The task is to parallelise the Annoy code across multiple nodes of the Spark cluster.
- A hint about the solution is present in this SO Answer.
- For both Annoy and Scikit, the approach is as follows:
- Build the index or fit the model on a single node. Nothing is distributed here.
- Broadcast the index (or model) across the cluster to find the nearest neighbours of a given vector.
Building the index
- I first tried to use spark-annoy. It is in Scala. The benefit of this library was that we could build the index in a distributed manner. Unfortunately, I could not figure it out.
- The default was to use the iterative approach of building the index on a single node.
- The following is the method to build the index:
```python
import pandas as pd
from annoy import AnnoyIndex


def build_annoy_index(vectors: pd.Series, dim: int, num_trees: int = 100):
    t = AnnoyIndex(dim, metric='angular')
    # add the vectors one by one; the Series index becomes the Annoy item id
    for index, vector in vectors.items():
        t.add_item(index, vector)
    t.build(num_trees)
    return t
```
- The `for` loop makes it a long-running process if the data is large. Sadly, it is unavoidable.
- Note that the type of `vectors` is `pd.Series`. I used pandas to get the goodness of indexes. It can be a `list` or any other iterable (you would then loop with `enumerate()` instead of `.items()`). Whatever its type, it has to be an in-memory iterable. That means either of the following:
- Call `.collect()` on the Spark DataFrame;
- Turn the Spark DataFrame into a pandas DataFrame.
- That will bring all the data to a single node. So, it can potentially lead to an OOM error.
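To get a feel for the OOM risk, here is a back-of-the-envelope estimate of the raw vector payload. It is only a sketch: the dimension of 300 is a made-up figure (the post does not state one), and 4 bytes per value assumes 32-bit floats. Python lists of floats inside a pandas Series would take considerably more than this lower bound.

```python
def raw_vector_bytes(num_records: int, dim: int, bytes_per_value: int = 4) -> int:
    """Lower bound on memory for the vectors alone (no index, no Python overhead)."""
    return num_records * dim * bytes_per_value


# 20 million records, hypothetical 300-dimensional vectors, 4-byte floats
size_gb = raw_vector_bytes(20_000_000, 300) / 1e9
print(f"{size_gb:.1f} GB")  # prints "24.0 GB"
```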
Finding Nearest Neighbours
- We can parallelise this step.
- We have to broadcast the Annoy index across all the nodes of the Spark cluster.
- The Annoy indexes are memory mapped. The Annoy documentation says: “It also creates large read-only file-based data structures that are mmapped into memory so that many processes may share the same data.”
- It will fail if we broadcast it using `sc.broadcast(t)`. This SO answer discusses this issue.
- The solution: write the index to a file and send the file to all the workers to load.
- Use `sc.addFile()` to send the file to the workers.
- Use `SparkFiles.get()` to get the file path and load it in the worker node.
- Here is the method to load the Annoy index in the worker nodes:
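```python
from pyspark import SparkFiles


def load_annoy_index(index_file: str, dim: int):
    # imported here so that the import happens on the worker
    from annoy import AnnoyIndex

    index = AnnoyIndex(dim, metric='angular')
    # resolve the file shipped with sc.addFile() to its path on this worker
    index.load(SparkFiles.get(index_file))
    return index
```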
- I call the below method to get the nearest neighbours of a set of index ids:
```python
def find_neighbours(index_file, top_n, dim, item_batch):
    index = load_annoy_index(index_file, dim)

    # get similar items
    sim_items = []
    for item in item_batch:
        top_n_items = index.get_nns_by_item(i=item, n=top_n)
        # pair every neighbour with its rank (0 = closest)
        sim_items.append((item, list(enumerate(top_n_items))))
    return sim_items
```
- `item_batch` is the list of Annoy index ids.
- The function returns the list of `(annoy_item_index, [(rank_1, sim_item_1), ..., (rank_n, sim_item_n)])`.
- For example, here is one such item from the list:
(248, [[0, 248], [9, 284764], [3, 86148], [6, 265812], [7, 508155], [2, 48388], [10, 58786], [1, 154653], [5, 364419], [4, 4444], [8, 89955]])
- To validate the function, the most similar item corresponding to the query item should be itself. In the above example, the query index `248` has rank `0` in the top similar items.
- To get the nearest neighbours by vectors, you pass the vectors in the `item_batch` and use the `get_nns_by_vector()` method.
- I keep my `find_neighbours` method generic for the following parameters:
- Annoy index file name: I can have any name based on my use case.
- Top n items: number of top items I want to retrieve.
- Dim: The dimension of the vector can vary depending on the various ML techniques (LDA, DL, etc.)
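The sanity check described earlier (every query item should come back as its own rank-0 neighbour) can be automated over the output of `find_neighbours`. This is a minimal sketch: the helper name `self_match_failures` and the second sample entry are made up for illustration. Since Annoy is approximate, an occasional failure (for example with duplicate vectors) is not necessarily a bug.

```python
def self_match_failures(results):
    """Return the query ids whose rank-0 neighbour is not the query itself."""
    failures = []
    for query_id, ranked in results:
        # ranked holds (rank, neighbour_id) pairs; their order is not assumed
        rank_to_neighbour = {rank: neighbour for rank, neighbour in ranked}
        if rank_to_neighbour.get(0) != query_id:
            failures.append(query_id)
    return failures


sample = [
    (248, [(0, 248), (1, 154653), (2, 48388)]),
    (7, [(0, 9), (1, 7)]),  # made-up entry that fails the check
]
print(self_match_failures(sample))  # prints [7]
```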
- We have only written the method to find the nearest neighbours. How do we call it in a distributed manner? This SO answer answers that too.
- The answer is `mapPartitions`. This method will apply the passed function to each RDD partition. Go through the answers of this SO question to know more in detail.
- I will pass `find_neighbours` to `mapPartitions`, and it will return an RDD with the nearest neighbours list.
- However, the `find_neighbours` implementation takes four parameters, and `mapPartitions` gives no way of sending the extra ones: it calls the function with a single argument, the iterator over the partition.
- I use the built-in Python `partial()` function from the `functools` module. The Python documentation describes it as follows: `partial()` is used for partial function application which “freezes” some portion of a function’s arguments and/or keywords resulting in a new object with a simplified signature.
- Here is how my final function looks:
```python
from functools import partial


# sc, dim and BASE_DIR are expected to be defined at module level
def build_index_get_similar_items(vector_df, index_file, top_n):
    vectors = vector_df.vector
    sparkvector_ids = sc.parallelize(vector_df.index.values)

    # build and save index
    index_file = build_save_annoy_index(vectors, dim, index_file, BASE_DIR)
    print(index_file)

    # add index file to the driver files
    sc.addFile(index_file)

    # get similar items
    find_neighbours_ = partial(find_neighbours, index_file, top_n, dim)
    similar_items = sparkvector_ids.mapPartitions(find_neighbours_).collect()
    return similar_items
```
- The `build_save_annoy_index()` method builds the index, saves it to a file, and returns the file path.
- Finally, we see the use of `partial()`. `find_neighbours_()` is the partial function. We froze `index_file`, `top_n`, and `dim`. This function now only expects a single RDD partition as input. And this is what we wanted for `mapPartitions`.
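To see the `partial()` plus `mapPartitions` interplay without a cluster, here is a plain-Python emulation. Everything in it is a stand-in: `find_neighbours_stub` only echoes its frozen arguments, `map_partitions_local` mimics how the function is called once per partition with an iterator, and `"items.ann"`, `10`, and `64` are arbitrary values. It is a sketch of the calling convention, not of Spark itself.

```python
from functools import partial
from itertools import chain


def find_neighbours_stub(index_file, top_n, dim, item_batch):
    # Stand-in for the real lookup: echo the frozen arguments for each item.
    return [(item, index_file, top_n, dim) for item in item_batch]


def map_partitions_local(partitions, func):
    """Call func once per partition (as an iterator) and concatenate the results."""
    return list(chain.from_iterable(func(iter(part)) for part in partitions))


# Freeze the first three arguments; only item_batch is left open.
frozen = partial(find_neighbours_stub, "items.ann", 10, 64)

partitions = [[0, 1], [2], [3, 4]]  # pretend the ids were split into 3 partitions
result = map_partitions_local(partitions, frozen)
print(result[0])  # prints (0, 'items.ann', 10, 64)
```

The frozen function takes exactly one argument, which is the shape `mapPartitions` expects.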
- I take the `similar_items` list and convert it into a pandas DataFrame.
- Map ALL the Annoy index ids with the actual item ids. That includes all the index ids of the top-n similar items list.
- Convert the pandas DataFrame to a PySpark DataFrame.
- Save the PySpark DataFrame into a delta table.
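The first two post-processing steps could look roughly like the sketch below. It is an assumption about the shape of the data, not the code used in the post: `flatten_similar_items`, the column names, and the `id_map` dictionary (Annoy index id to actual item id) are hypothetical. The resulting list of dicts can be handed to `pd.DataFrame(rows)`.

```python
def flatten_similar_items(similar_items, id_map):
    """Turn (query_idx, [(rank, neighbour_idx), ...]) tuples into flat rows with real ids."""
    rows = []
    for query_idx, ranked in similar_items:
        for rank, neighbour_idx in ranked:
            rows.append({
                "item_id": id_map[query_idx],              # map the query id
                "rank": rank,
                "similar_item_id": id_map[neighbour_idx],  # map every neighbour id too
            })
    return rows


# made-up sample: Annoy ids 0 and 2 map to item ids "a" and "c"
rows = flatten_similar_items([(0, [(0, 0), (1, 2)])], {0: "a", 2: "c"})
print(rows[1])  # prints {'item_id': 'a', 'rank': 1, 'similar_item_id': 'c'}
```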
- I was able to parallelise the Annoy-based kNN search in PySpark.
- On ~500k records, the run time was down from 8 minutes to 2 minutes.
- On ~10 million records (with an index built from ~500k records), the run time was ~1 hour.
- On ~10 million records (with an index built from ~10 million records), I got an OOM error. 🥲
- Find the reason it is going OOM.
- Converting the pandas DataFrame to PySpark DataFrame is expensive. I want to explore if I can directly go from pandas DataFrame to the delta. Ref: 1, 2.
- Since Scikit-learn has vectorised training and inference, its kNN would likely be faster. This post shows how to do it. I would probably replace the