Low latency adaptive machine learning

Despite the wide range of applications machine learning is used in, pretty much all of them follow a workflow something like:

  1. Gather data
  2. Label data
  3. Train model
  4. Deploy model

Unfortunately, data drift occurs in almost all domains. It slowly makes your model inaccurate, requiring follow-up iterations of gathering, labelling, training and deployment.

Usually data drift isn't caused by the optimal output changing for a given input, but rather by the inputs drifting into new areas of the input space that no training data exists for yet. This means the model has to extrapolate from more distant training data, which it may do more or less successfully depending on the exact situation. A classic example: consumer cellphone cameras have improved in quality over time, so a model trained on old, blurry camera data may now (counterintuitively) be less accurate on the new, clean images, simply because nothing like them appears in the training data.

Another interesting property of many machine learning problems is that the whole purpose of the prediction is to estimate some piece of information that will become available anyway shortly afterwards - especially in the realm of time series or behavioural data. These domains are also frequently the ones most affected by model drift, due to anything from new technology being introduced to economic outlooks changing.

The great thing is that we can use this situation to keep the model up to date automatically. Because the ground truth arrives on its own shortly after each prediction, every new observation can be fed back in as a training sample, and the production model can take into account the latest behavioural changes seen in real-world data.
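
To make that concrete, here is a minimal sketch of the feedback loop in C++. The model is a deliberately trivial running mean, a hypothetical stand-in rather than anything from the implementations mentioned later; the only point is that the true value arrives shortly after each prediction is served, so it can be folded straight back into the model before the next prediction.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-in model: predicts the mean of everything seen so far.
struct RunningMeanModel {
    double sum = 0.0;
    long count = 0;

    double predict() const { return count ? sum / count : 0.0; }
    void update(double target) { sum += target; ++count; }
};

int main() {
    // Stream of observations whose true values become known shortly after
    // each prediction is served (e.g. the next interval of a time series).
    std::vector<double> stream = {1.0, 1.2, 0.9, 2.5, 2.7};

    RunningMeanModel model;
    for (double truth : stream) {
        double prediction = model.predict();  // serve with what we know now
        std::printf("predicted %.2f, observed %.2f\n", prediction, truth);
        model.update(truth);                  // ground truth arrives -> learn
    }
}
```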

Some types of machine learning systems are more amenable than others to small updates of the production model. Systems like SVMs, random forests and boosting are notoriously difficult to update incrementally. For [D]NN systems, excessive training can cause model collapse, but through judicious use of checkpoints to keep a strong base model, techniques like LoRA can provide updates for the latest data. However, by far the easiest way I've found to do low latency, adaptive machine learning, with the absolute shortest time between new samples being ready and the model being updated, is via Kd-Tree powered K-Nearest-Neighbours (KNN).
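
To illustrate why KNN is such a good fit, here is a rough sketch (my own naming, not the open-sourced implementations mentioned below): "training" on a new sample is nothing more than storing it, so the very next query already reflects it. This version does a brute-force scan over all points; the Kd-Tree discussed next replaces that scan for larger datasets.

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

struct Sample {
    std::vector<double> x;
    double y;
};

class KnnRegressor {
public:
    // Adding a sample *is* the model update - no retraining pass required.
    void add(std::vector<double> x, double y) { data_.push_back({std::move(x), y}); }

    // Predict by averaging the targets of the k nearest stored samples.
    double predict(const std::vector<double>& q, std::size_t k) const {
        std::vector<std::pair<double, double>> byDistance;  // (distance^2, target)
        for (const Sample& s : data_) {
            double d2 = 0.0;
            for (std::size_t i = 0; i < q.size(); ++i) {
                double diff = s.x[i] - q[i];
                d2 += diff * diff;
            }
            byDistance.emplace_back(d2, s.y);
        }
        k = std::min(k, byDistance.size());
        std::partial_sort(byDistance.begin(), byDistance.begin() + k, byDistance.end());
        double sum = 0.0;
        for (std::size_t i = 0; i < k; ++i) sum += byDistance[i].second;
        return k ? sum / k : 0.0;
    }

private:
    std::vector<Sample> data_;
};

int main() {
    KnnRegressor model;
    model.add({0.0, 0.0}, 1.0);
    model.add({1.0, 0.0}, 2.0);
    model.add({0.0, 1.0}, 3.0);
    // Averages the targets of the two nearest stored samples.
    std::printf("%.2f\n", model.predict({0.1, 0.1}, 2));
}
```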

Because all of the data has to be accessible at regression/classification time, larger problems call for some kind of data structure to accelerate searches. As long as you keep your underlying dimensionality under around 8 dimensions (which can correspond to many more raw dimensions in real-world data with internal correlations), exact lookup strategies are very viable for any dataset that fits in RAM. For higher-dimensional data, it's typically necessary to move to approximate, probabilistic approaches.

Most accelerated KNN implementations (exact and probabilistic) don't allow the data to be updated dynamically without an entire rebuild of the search acceleration structure. Even structures such as Kd-Trees are often written with APIs that assume all data is available upfront. However, if you design your structure upfront to allow dynamic updates, adding data 'online', i.e. on a sample-by-sample basis, while keeping the tree ready for queries over all data seen so far, carries only a relatively modest penalty.
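
As a rough illustration of the API shape, here is a stripped-down sketch of a Kd-Tree built around online insertion: points are added one at a time, and exact nearest-neighbour queries over everything inserted so far are available at any moment. This is not the open-sourced implementation mentioned below, and it deliberately ignores rebalancing (naive inserts can skew the tree over time, so a real implementation typically rebuilds or rebalances subtrees periodically).

```cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <memory>

constexpr std::size_t DIMS = 2;
using Point = std::array<double, DIMS>;

class DynamicKdTree {
public:
    // Online insert: O(tree depth); the point is immediately visible to queries.
    void insert(const Point& p) { insert(root_, p, 0); }

    // Exact nearest neighbour over all points inserted so far.
    Point nearest(const Point& q) const {
        Point best{};
        double bestDist2 = std::numeric_limits<double>::infinity();
        nearest(root_.get(), q, 0, best, bestDist2);
        return best;
    }

private:
    struct Node {
        Point point{};
        std::unique_ptr<Node> left, right;
    };
    std::unique_ptr<Node> root_;

    static double dist2(const Point& a, const Point& b) {
        double d = 0.0;
        for (std::size_t i = 0; i < DIMS; ++i) d += (a[i] - b[i]) * (a[i] - b[i]);
        return d;
    }

    static void insert(std::unique_ptr<Node>& node, const Point& p, std::size_t depth) {
        if (!node) {
            node = std::make_unique<Node>();
            node->point = p;
            return;
        }
        std::size_t axis = depth % DIMS;
        insert(p[axis] < node->point[axis] ? node->left : node->right, p, depth + 1);
    }

    static void nearest(const Node* node, const Point& q, std::size_t depth,
                        Point& best, double& bestDist2) {
        if (!node) return;
        double d2 = dist2(node->point, q);
        if (d2 < bestDist2) { bestDist2 = d2; best = node->point; }

        std::size_t axis = depth % DIMS;
        double delta = q[axis] - node->point[axis];
        const Node* nearSide = delta < 0 ? node->left.get() : node->right.get();
        const Node* farSide  = delta < 0 ? node->right.get() : node->left.get();

        nearest(nearSide, q, depth + 1, best, bestDist2);
        // Only descend the far side if the splitting plane is closer than the
        // best match found so far.
        if (delta * delta < bestDist2) nearest(farSide, q, depth + 1, best, bestDist2);
    }
};

int main() {
    DynamicKdTree tree;
    tree.insert({2.0, 3.0});
    tree.insert({5.0, 4.0});
    tree.insert({9.0, 6.0});
    Point n = tree.nearest({8.0, 7.0});
    std::printf("nearest: (%.1f, %.1f)\n", n[0], n[1]);

    // New data can be folded in between queries with no rebuild step.
    tree.insert({8.5, 7.5});
    n = tree.nearest({8.0, 7.0});
    std::printf("nearest after insert: (%.1f, %.1f)\n", n[0], n[1]);
}
```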

I have two implementations of this data structure, one in Java and another in C++. They are open sourced here.

If you end up using these data structures, I'd love to hear about it and the problems you're solving.