Interview Case Study #1: The Statistics Of KNN Parameter Optimization
By Dae Won Kim, President and Yelp Dataset Challenge Team Project Manager
This semester undoubtedly marks a period of explosive growth for Cornell Data Science. Over 150 aspiring students applied to our core teams, and a total of 120 people have signed up for credit in our Data Science Training Program. With the sharp increase in the number of applications, our project managers have fought hard to maintain a high level of professionalism and technical rigor in the selection process.
A key component of this year’s selection process was the technical interview. We’d like to share the questions asked in these case studies to offer a glimpse of what we do at Cornell Data Science.
When interviewing students for the Yelp Dataset Challenge team, I focused on a specific algorithm: k-Nearest Neighbor (kNN). The following case study dives into the intuition behind kNN and the potential problems that might come up when implementing it.
Background
What is kNN?
kNN is a very simple, and often overlooked, machine learning algorithm that can be used as both a regressor and a classifier. In the case of classification, kNN estimates the conditional probability that a new point belongs to a given class as the proportion of that class among the k nearest points in the training set, and then assigns the class with the highest proportion. Below is a good visual representation.
Variations in the kNN algorithm are based on both the distance metric and whether it is used as a regression or classification method. For our thought experiment, we will only consider the classification case. There are also many choices of distance metric available (like Euclidean and Manhattan, to name a few) that I will not elaborate upon here. For the sake of this thought experiment, remember that the only input parameters to this algorithm are k (number of neighbors to consider) and the distance metric.
Interview Question #1: What are potential problems with implementing kNN on a very large dataset?
Answer:
One must understand what operations happen during each iteration of the algorithm. For each new data point, the kNN classifier must:
- Calculate the distances to all points in the training set and store them
- Sort the calculated distances
- Store the k nearest points
- Calculate the proportions of each class
- Assign the class with the highest proportion
Obviously this is a very taxing process, both in terms of time and space complexity. For a single new point, the distance calculation is linear in the size of the training set, and the sort is an O(n log n) process; classifying a test set of comparable size therefore costs on the order of O(n² log n), a monstrously long process indeed.
Another problem is memory, since all of the computed distances must be stored and sorted in memory on a single machine. With very large datasets, local machines will usually crash.
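To make these steps concrete, here is a minimal NumPy sketch of a single classification. The toy data and the Euclidean metric are just illustrative assumptions; the point is that every prediction repeats the full distance-and-sort routine.

```python
import numpy as np
from collections import Counter

def knn_classify(query, X_train, y_train, k=5):
    """Classify a single query point with a naive kNN, following the steps above."""
    # 1. Calculate the distances to all points in the training set and store them
    distances = np.linalg.norm(X_train - query, axis=1)
    # 2. Sort the calculated distances (argsort gives the ordering of the indices)
    order = np.argsort(distances)
    # 3. Store the labels of the k nearest points
    nearest_labels = y_train[order[:k]]
    # 4./5. Calculate the class proportions and return the most common class
    return Counter(nearest_labels).most_common(1)[0][0]

# Toy example (made-up data, just to show the call)
rng = np.random.default_rng(0)
X_train = rng.normal(size=(1000, 2))
y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(int)
print(knn_classify(np.array([0.5, 0.5]), X_train, y_train, k=5))
```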
Interview Question #2: What are some ways of getting around the kNN-specific problems?
Answer:
Solution #1: Get more resources (computing power or larger memory).
This is obviously not the best answer to a scalability question, and it is not really applicable to real-life industry problems.
Solution #2: Preprocessing the data.
Use dimensionality reduction (via PCA, principal component analysis, or feature selection) to reduce the cost of each distance calculation. You can also use clustering algorithms (like K-means or Rocchio) to reduce the number of points used to compute distances and sort, as illustrated below. In this case, the nontrivial task becomes assigning the test set point to the correct cluster.
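Here is a rough scikit-learn sketch of both ideas, dimensionality reduction followed by point reduction, on synthetic data standing in for the real training set; the number of components and clusters are illustrative assumptions, not tuned values.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in data; in practice this would be the real training set.
X, y = make_classification(n_samples=20_000, n_features=50, random_state=0)

# Dimensionality reduction: keep only the top principal components,
# which makes every distance computation cheaper.
pca = PCA(n_components=10).fit(X)
X_reduced = pca.transform(X)

# Point reduction: replace each class's points with a handful of K-means
# centroids, so far fewer distances need to be computed and sorted.
centroids, centroid_labels = [], []
for label in np.unique(y):
    km = KMeans(n_clusters=50, n_init=10, random_state=0).fit(X_reduced[y == label])
    centroids.append(km.cluster_centers_)
    centroid_labels.append(np.full(50, label))

knn = KNeighborsClassifier(n_neighbors=5).fit(
    np.vstack(centroids), np.concatenate(centroid_labels)
)

# A new point is projected with pca.transform before being classified.
print(knn.predict(pca.transform(X[:5])))
```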
Solution #3: Random sampling to reduce training set size.
If you use a good random number generator and a decently large sample size, the sample should be a fairly good representation of the original.
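A minimal sketch of that idea, assuming NumPy and a hypothetical sampling fraction `frac`:

```python
import numpy as np

def subsample(X_train, y_train, frac=0.1, seed=0):
    """Draw a uniform random subsample of the training set before fitting kNN."""
    rng = np.random.default_rng(seed)
    n = len(X_train)
    idx = rng.choice(n, size=int(frac * n), replace=False)
    return X_train[idx], y_train[idx]

# Usage: fit any kNN model on the reduced training set instead of the full one.
# X_small, y_small = subsample(X_train, y_train, frac=0.1)
```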
Interview Question #3: If using random sampling only once, and supposing we know a good k value for the original data, how should k be adjusted to account for the change in input size?
Answer:
Two important points must be clarified to tackle this problem:
- What effect does sampling have on the kNN model?
- What effect does changing k have on the kNN model?
Point #1: Effects of sampling:
As illustrated above, sampling does several things from the perspective of a single data point, since kNN works on a point-by-point basis.
- The average distance to the k nearest neighbors increases due to increased sparsity in the dataset.
- Consequently, the area covered by k-nearest neighbors increases in size and covers a larger area of the feature space.
- The sample variance increases.
A consequence of this change in input is an increase in variance. When we talk of variance, we refer to the variability in the predictions given different samples from the population. Why would the immediate effects of sampling lead to increased variance of the model?
Notice that now a larger area of the feature space is represented by the same k data points. While our sample size has not grown, the population space that it represents has increased in size. This will result in higher variance in the proportion of classes in the k nearest data points, and consequently a higher variance in the classification of each data point.
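One rough way to see this effect in code is to refit kNN on many independent subsamples and measure how much the predictions flip from sample to sample. The sketch below uses synthetic scikit-learn data, so the exact numbers are illustrative only.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

# Synthetic "population" and a fixed set of test points (illustrative only).
X, y = make_classification(n_samples=50_000, n_features=5, random_state=0)
X_test = X[:500]

def prediction_variance(sample_size, k=5, n_resamples=20, seed=0):
    """Average per-point variance of kNN predictions across independent subsamples."""
    rng = np.random.default_rng(seed)
    preds = []
    for _ in range(n_resamples):
        idx = rng.choice(len(X), size=sample_size, replace=False)
        model = KNeighborsClassifier(n_neighbors=k).fit(X[idx], y[idx])
        preds.append(model.predict(X_test))
    # For binary labels, the per-point variance across resamples captures
    # how often the classification flips from sample to sample.
    return np.mean(np.var(np.array(preds), axis=0))

print(prediction_variance(500), prediction_variance(5_000))
```

Typically, the smaller sample produces the larger average variance, which is exactly the effect described above.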
Point #2: Effects of Changing the k Parameter in kNN
Let us first examine the visual effect of changing k from k=1 to k=5 on a particular dataset.
Notice from the comparison that:
- The number of distinct regions (in terms of color) goes down when the k parameter increases.
- The class boundaries of the predictions become smoother as k increases.
What really is the significance of these effects? First, they hint that a lower k value makes the kNN model more “sensitive,” that is, more sensitive to local changes in the dataset. The “sensitivity” of the model directly translates to its variance.
All of these examples point to an inverse relationship between variance and k. Additionally, consider how kNN operates when k reaches its maximum value, k = n (where n is the number of points in the training set). In this case, the majority class in the training set will always dominate the predictions: the model simply picks the most abundant class in the data and never deviates, effectively resulting in zero variance. Therefore, to reduce variance, k must be increased.
Final Verdict: In order to offset the increased variance due to sampling, k can be increased to decrease model variance.
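As a quick sanity check of the k = n limit (and of the contrast with k = 1), here is a small sketch on made-up data with scikit-learn; the specific dataset is an assumption for illustration only.

```python
import numpy as np
from collections import Counter
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 2))
y = (X[:, 0] > 0.3).astype(int)   # class 1 is the minority class here

# With k equal to the training-set size, every query sees the whole
# training set, so the prediction is always the majority class.
knn_all = KNeighborsClassifier(n_neighbors=len(X)).fit(X, y)
print(Counter(knn_all.predict(rng.normal(size=(100, 2)))))  # only the majority class

# With k = 1 the prediction tracks whichever single point happens to be closest.
knn_one = KNeighborsClassifier(n_neighbors=1).fit(X, y)
print(Counter(knn_one.predict(rng.normal(size=(100, 2)))))  # a mix of both classes
```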
Interview Question #4: If not restricted to a single sample, what could be a fairly simple method to reduce increased variance of the model other than changing k?
Answer:
If one is not restricted in the number of times samples can be drawn from the original dataset, a simple variance reduction method would be to sample many times, and then simply take a majority vote of the kNN models fit to each of these samples to classify each test data point. This variance reduction method is called bagging. You might have heard of bagging, since it is the core concept in random forest, a very popular tree ensemble method. We will explore this technique in greater detail in future posts.
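As a minimal sketch of bagged kNN, here is one way to wire it up with scikit-learn's BaggingClassifier on synthetic data; the sample fraction and number of estimators are illustrative assumptions, not tuned values.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in data for illustration.
X, y = make_classification(n_samples=20_000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Bagged kNN: fit many kNN models, each on a small random sample of the
# training set, and classify each test point by majority vote.
bagged_knn = BaggingClassifier(
    KNeighborsClassifier(n_neighbors=5),
    n_estimators=25,   # number of subsamples / kNN models (illustrative)
    max_samples=0.1,   # each model sees 10% of the training data
    random_state=0,
).fit(X_train, y_train)

print(bagged_knn.score(X_test, y_test))
```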
Conclusion:
Why was kNN chosen as the interview problem? Not because it’s the most relevant algorithm to our team’s work, but because it’s a very simple model to understand.
These questions were presented to a broad spectrum of student applicants, from seniors and graduate students to freshmen and sophomores. Not wanting to penalize people for lack of theoretical knowledge, we picked an algorithm that did not require any rigorous mathematical derivations, but did require good statistical intuition and visual thinking. While some students arrived at different answers from those presented here, answers based on solid statistical counterarguments were accepted and even appreciated!
While kNN may not be the most popular or relevant algorithm in data science today, the thought process laid out in this post applies to every machine learning algorithm, and therefore must be in the skill set of every data scientist. Data manipulation and transformation are intrinsically tied to the decisions one must make when adjusting parameters.
P.S. Have any suggestions on how to solve these questions? Any thoughts and comments would be much appreciated.