Long short-term memory models (LSTMs) are a family of recurrent neural networks that are predominantly used to predict the next value given a historical chain of previous values. These can be numerical predictions (e.g. where is the stock price going based on historical stock data) or categorical predictions (e.g. what stop is the passenger going to get off at, given the previous stops they've gotten off at). Not all that long ago, the "state of the art" for generating text with neural nets was LSTMs trained to predict the next word given a sequence of previous words. These models are a good fit if you're interested in predicting the next thing that will happen given a complicated chain of things that happened previously, and this explanation and code will walk you through how to do it. Anyway - I've had several prospective clients ask me about using LSTMs to predict things, and I wanted to put together a PoC codebase that illustrates how they work. So I did!

The code I've linked to is a set of classes that jointly (a) create synthetic data in SyntheticDataGenerator, composed of "individuals" transitioning between "events" (e.g. someone was on the front page of a website, A, then went to the about page, C, then contact, B, then back to the front page, A), (b) build an LSTM in DemoModel that can interpret those events, train on them, and predict next events given a historical chain of events, and (c) run a series of ValidationDiagnostics tests that use out-of-sample data to confirm that the model does what it says it does. The function DemoModel.run_demo() runs all of the tests end to end, or you can copy and paste its lines one at a time to see the output.
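To make "predict next events given a historical chain of events" a bit more concrete, here is a minimal sketch of how one chain can be cut into training examples. This is an illustration only, not the DemoModel internals: the function name make_training_pairs and the window parameter are made up for this example, and the real code may prepare its inputs differently.

```python
def make_training_pairs(chain, window):
    """Turn one event chain into (history, next_event) examples.

    Hypothetical helper: `window` caps how many prior events are kept
    as history for each prediction target.
    """
    pairs = []
    for end in range(1, len(chain)):
        # Keep at most `window` events immediately before the target.
        history = chain[max(0, end - window):end]
        pairs.append((history, chain[end]))
    return pairs


# The website example from above: front page, about, contact, front page.
chain = ["A", "C", "B", "A"]
pairs = make_training_pairs(chain, window=2)
for history, target in pairs:
    print(history, "->", target)
```

Each pair is one question for the model: given this history, what comes next? A sequence model would typically see the events as integer IDs rather than letters, but the shape of the problem is the same.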

The setup is - let's build synthetic data that assumes there are eight possible event types, and then we can specify some more-likely-than-chance or less-likely-than-chance transition probabilities between events. All unmentioned transition pairs are assumed to just be x/k, where x is the unallocated transition probability and k is the number of ways one can transition from one event to another (i.e. always 7 in this model). In my demo version, I have these probabilities mapped out in the SyntheticDataGenerator:

events = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
transition_dict = {
    'A': [('B', 0.5), ('C', 0.1), ('D', 0.4)],
    'C': [('D', 0.8)],
    'D': [('A', 0.9)],
}

That is to say, D→A transitions are very common, as are C→D, while A→C is rare. We also have a tunable parameter, memory_influence. When memory_influence is 1.0, a person transitioning through these events will always choose a prior event (i.e. just stay exactly where they started). When memory_influence is 0.0, they will transition entirely randomly according to our map. This setup is important because it helps us do a few things:

  • When memory_influence is very low: this means we are basically just riding along a transition map, where every next step is random subject to our weights in the transition dictionary. As such, if we train an LSTM on the synthetic data, we should see this characteristic being largely recovered (i.e. ≈90% of people at event D will transition to A). Even though the transitions will be reliably correlated to the transition_dict, our predictive power will be relatively low, because there's no dependence on prior states, so the model learns what the transition dictionary says - just choose proportionally, randomly.

  • When memory_influence is very high: we should see our model coming off the peg of some baseline random-dice accuracy. That is to say - if you guess the roll of a six-sided die, you will be right 1/6th of the time, just like you'd be right 50% of the time with a coin flip. With our more complicated dice of people moving between events, we have some baseline accuracy just because of the proportional allocation of transitions. As we increase memory_influence, we are inducing a "bonus" artifact for the LSTM to learn about - and the higher it goes, the more reliable that artifact becomes for the LSTM in terms of predicting further outcomes, and the higher the accuracy goes. In the trivial case, when memory_influence is 1.0, everyone just stays on the original event they started at, so the LSTM should be scoring at ≈100% by just saying "everyone stays put".

  • For everything in between: we get to see the handoff between baseline accuracy and "memory-boosted" accuracy, and how that comes together as a function of increasing memory_influence. This helps calibrate our understanding as programmers in terms of the order of magnitude of additional lift the model gives in this toy demo circumstance, and helps illustrate what the LSTM is doing above and beyond random chance.
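Here is a rough sketch of how a generator like this could work, using only the standard library. It is not the actual SyntheticDataGenerator code, and it bakes in two assumptions of mine: leftover probability is split evenly over every target not listed for an event (self-transitions included, which gives k = 7 for C and D), and a "memory" step picks uniformly from the events the individual has already visited. The names transition_row, generate_chain, and best_guess_accuracy are made up for this example.

```python
import random

EVENTS = ["A", "B", "C", "D", "E", "F", "G", "H"]
TRANSITION_DICT = {
    "A": [("B", 0.5), ("C", 0.1), ("D", 0.4)],
    "C": [("D", 0.8)],
    "D": [("A", 0.9)],
}


def transition_row(current):
    """Return {target: probability} for one event, filling unlisted targets."""
    listed = dict(TRANSITION_DICT.get(current, []))
    unlisted = [e for e in EVENTS if e not in listed]
    leftover = max(0.0, 1.0 - sum(listed.values()))
    # Assumption: leftover mass is shared evenly by all unlisted targets.
    share = leftover / len(unlisted) if unlisted else 0.0
    return {e: listed.get(e, share) for e in EVENTS}


def generate_chain(length, memory_influence, rng):
    """Generate one individual's chain of events."""
    chain = [rng.choice(EVENTS)]
    while len(chain) < length:
        if rng.random() < memory_influence:
            # Memory step (assumed mechanism): revisit an earlier event.
            chain.append(rng.choice(chain))
        else:
            # Map step: sample the next event from the transition row.
            row = transition_row(chain[-1])
            targets = list(row)
            weights = [row[t] for t in targets]
            chain.append(rng.choices(targets, weights=weights)[0])
    return chain


def best_guess_accuracy(current):
    """Best achievable accuracy from `current` when memory_influence is 0."""
    return max(transition_row(current).values())


rng = random.Random(0)
print(generate_chain(10, 0.0, rng))   # pure transition-map behavior
print(generate_chain(10, 1.0, rng))   # everyone stays on their start event
print({e: round(best_guess_accuracy(e), 3) for e in EVENTS})
```

Under these assumptions, the best a model can do with no memory effect is to always guess the most likely target for the current event: 0.9 from D, 0.8 from C, 0.5 from A, and 0.125 from the events with no listed transitions. That per-event ceiling is the "baseline" the bullets above are talking about, and at memory_influence = 1.0 every chain collapses to a single repeated event.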

We can measure how well the LSTM performs in a few ways. Obviously, first, how accurate is it? We can measure that as the number of correctly predicted next states divided by the number of next states we attempt to predict in some dataset. Of course, with lots of choices, getting it exactly right becomes rarer. So instead, we may want to see whether the true next event was among the LSTM's top k suggestions (its best choice, second-best choice, and so forth up to the k-th). With 8 total events, when k=8, the probability that the true event is in the top 8 is 100% - but it may be interesting to see how, as memory_influence increases, the percentage correct increases faster with higher values of k. One interesting thing I see in the results is that accuracy doesn't just get monotonically better as memory_influence increases.
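The k-retrieval idea can be sketched in a few lines. This is a simplified stand-in rather than the ValidationDiagnostics code: top_k_accuracy is a made-up name, and the predicted probabilities below are toy numbers chosen by hand, not model output.

```python
def top_k_accuracy(prob_rows, true_events, k):
    """Share of predictions where the true event is among the top k guesses.

    prob_rows: one {event: predicted probability} dict per prediction.
    true_events: the event that actually came next, in the same order.
    """
    hits = 0
    for probs, truth in zip(prob_rows, true_events):
        # Rank events from most to least likely and keep the first k.
        ranked = sorted(probs, key=probs.get, reverse=True)
        hits += truth in ranked[:k]
    return hits / len(true_events)


# Toy predictions over three events (hand-picked numbers).
prob_rows = [
    {"A": 0.7, "B": 0.2, "C": 0.1},
    {"A": 0.1, "B": 0.3, "C": 0.6},
    {"A": 0.4, "B": 0.5, "C": 0.1},
]
true_events = ["A", "B", "A"]

for k in (1, 2, 3):
    print(k, top_k_accuracy(prob_rows, true_events, k))
```

With k = 1 this is plain accuracy, and with k equal to the number of events it is always 1.0, so the interesting part of the curve is what happens in between.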

This may just be an artifact of random chance, but I think it's not, and I think the reason it happens is that memory_influence actually confounds the model at middling values, until it becomes a dominant enough characteristic that the model is better served by relying on the patterns it causes. Here's a sheet with some diagnostic output, and here are some charts from that sheet for folks too lazy to jump over:

Overall Accuracy as Memory Influence Parameter Changes

K-retrieval accuracy as Memory Influence Parameter Changes

The charts largely confirm the theory though - that as memory_influence increases, after our weird uncanny valley, it becomes a dominant factor that the LSTM can rely on and predict off of, and get increasingly accurate predictions out of. While, obviously, memory_influence is only "sort of" like how people transition between states, it at least mimics some first-order behaviors (e.g. people will revisit subreddits they've visited much more often than they jump to completely new subreddits).

In my own PoC, I trained an LSTM on ≈50,000 historical bid chains of people bidding over 6 million times on cars on BringATrailer.com. I wanted to demonstrate how one could use LSTMs to build recommendation engines that would, in this case, theoretically show people cars they are likely to bid on, conditional on the cars they previously bid on. I coarsened the data so the model predicts not the specific car but just the make / model - can we, in short, show people the types of cars that will be relevant to their bidding strategies?
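As a rough idea of what that coarsening step could look like, here is a small sketch. It is not the code I ran: the coarsen function, the record layout, and the listing IDs are all invented for illustration, and dropping rare labels entirely (rather than, say, bucketing them as "other") is just one possible choice.

```python
from collections import Counter


def coarsen(chains, top_n):
    """Map each bid to a "make model" label and keep the top_n labels.

    chains: one list of (make, model, listing_id) tuples per bidder.
    Bids on labels outside the top_n most frequent are dropped (assumption).
    """
    labelled = [
        [f"{make} {model}" for make, model, _listing_id in chain]
        for chain in chains
    ]
    counts = Counter(label for chain in labelled for label in chain)
    keep = {label for label, _count in counts.most_common(top_n)}
    return [[label for label in chain if label in keep] for chain in labelled]


# Made-up bid histories for two bidders.
chains = [
    [("Porsche", "911", 1), ("BMW", "M3", 2), ("Porsche", "911", 3)],
    [("Porsche", "911", 4), ("Ford", "Bronco", 5), ("BMW", "M3", 6)],
]
print(coarsen(chains, top_n=2))
```

The resulting label chains have the same shape as the A-through-H event chains in the synthetic demo, which is what lets the same model code be reused.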

I built out a network graph to visualize the sequential bid behavior, and it clearly showed clusters that made sense (e.g. each cluster of cars falls into a fandom that is predominantly geographically bounded, like European cars, American muscle vs American trucks, etc). Given that there was coherence within the data that indicated strong structural preferences above and beyond random behavior, I moved on to building the bid-based LSTM using the same DemoModel code shown in the above GitHub gist.

Network graph of sequential bid pairs on BringATrailer.com, up to early 2024
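The raw material for a graph like this is just a count of consecutive bid pairs. A minimal sketch, assuming each bidder's history is an ordered list of make / model labels (the function name sequential_pair_counts and the toy labels are made up, and the actual graph may have been built differently):

```python
from collections import Counter


def sequential_pair_counts(chains):
    """Count directed (previous_label, next_label) pairs across all chains."""
    edges = Counter()
    for chain in chains:
        # zip pairs each bid with the bid that immediately follows it.
        edges.update(zip(chain, chain[1:]))
    return edges


# Made-up bid histories for two bidders.
chains = [
    ["Porsche 911", "BMW M3", "Porsche 911"],
    ["Porsche 911", "BMW M3"],
]
edges = sequential_pair_counts(chains)
for (source, target), weight in edges.most_common():
    print(source, "->", target, weight)
```

Each counted pair becomes a weighted edge, and heavily weighted edges are what pull related makes and models into visible clusters once the graph is laid out.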

K-retrieval chart for BringATrailer.com LSTM bid predictor

This LSTM was built to predict which of 160 common make/models the user would likely bid on next, and it did in fact learn some patterns (which I again report with a k-retrieval chart). It took about 6 hours to train on a 4090 on Runpod, which I love for training this type of stuff. Overall, starting with synthetic data helps to confirm that the model does in fact work from a technical perspective, and helps to ground beliefs as to how much it can help out of the box.

In a further iteration, one could imagine making up a theory of how bid behavior works, generating synthetic data based on that theory, and then testing whether the resulting lift deviates significantly from the lift observed in our original toy model. In any case, I hope you run with this LSTM demo code and enjoy!