Advent of Code 2020 in Haskell - Day 15




Parts 1 and 2

Day 15 was one of those days where the extension presented in part 2 simply scales the problem up, as a sort of "efficiency gate" to make those who solved part 1 by brute force rethink their algorithm. We need to simulate a number game: on each turn we consider the number most recently spoken; if it was being said for the first time, the next number is 0, otherwise it is the gap between the turn on which it was just spoken and the turn on which it was spoken before that. Starting from the puzzle's example of 0,3,6, the sequence continues 0,3,3,1,0,4,0, and so on. Simulating this efficiently comes down to tracking the last seen occurrence of each integer. Happily, my solution to part 1 was more elegant than a simple brute force, so both parts were solved by the code below:

    
import D15input
-- input :: [Int]
import qualified Data.IntMap.Lazy as Map

-- 'memo' maps each number to the turn it was most recently spoken on,
-- and 'last' is the number spoken on turn 'pos'. Recurse until pos == n.
getNth :: Int -> Map.IntMap Int -> Int -> Int -> Int
getNth n memo last pos
    | pos == n = last
    | otherwise =
        let updatedMemo = Map.insert last pos memo
            diff = case Map.lookup last memo of
                    Nothing -> 0        -- 'last' is new: the next number is 0
                    Just i -> pos - i   -- otherwise: the gap since its previous occurrence
        in getNth n updatedMemo diff (pos + 1)

solve :: Int
solve = let initMap = Map.fromList $ zip input [1..]
        in getNth 2020 initMap 0 (length input + 1)

main :: IO ()
main = print solve
    

The core of the algorithm is to maintain a map from each integer to the turn on which it was last seen. At each step, we look up the previous occurrence of the number that was just spoken and compute the difference (called diff above) between the current turn and that occurrence. If the map does not yet contain the integer, the lookup returns Nothing, meaning the number is new, so the next number spoken is 0.

This is all wrapped up in a recursive function getNth, which keeps performing these steps until the position we are at matches the requested turn n. The above solves part 1, and part 2 is solved by simply replacing 2020 with 30000000 in the solve function.
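
As a quick sanity check, the puzzle's worked example can be fed through the same function after loading the module in GHCi: starting from 0,3,6, the 2020th number spoken should be 436. The arguments mirror solve, with last = 0 (the last starting number, 6, is new, so the number spoken on turn 4 is 0) and pos = 4 (one past the three starting numbers):

    
ghci> let exampleMap = Map.fromList $ zip [0,3,6] [1..]
ghci> getNth 2020 exampleMap 0 4
436
    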

Improving Performance with Mutable Data Structures

The above approach does indeed solve Part 2, but the compiled binary still takes around 35 seconds to return the answer on my machine. Advent of Code guarantees that every problem has a solution that completes in under 15 seconds on ten-year-old hardware, so I wanted to spend some time looking at how to improve performance.

Just to be sure that the algorithmic approach was sound, I threw together the above solution in C++ to see how quickly it ran:

    
#include <iostream>
#include <vector>

// Returns the number spoken after 'last' (which was spoken on turn 'pos'),
// and records that 'last' was spoken on turn 'pos'.
int step(std::vector<int> &memo, int last, int pos) {
    int diff;
    if (memo[last] == -1) diff = 0;
    else diff = pos - memo[last];
    memo[last] = pos;
    return diff;
}

int main() {
    // memo[x] = the turn on which x was last spoken, or -1 if never.
    // Seeded with the starting numbers 0,14,6,20,1,4 and their turns.
    std::vector<int> memo(30000000, -1);
    memo[0] = 1; memo[14] = 2; memo[6] = 3;
    memo[20] = 4; memo[1] = 5; memo[4] = 6;

    int last = 0;  // the number spoken on turn 7 (4 was new on turn 6)
    for (int i = 7; i < 30000000; ++i) {
        last = step(memo, last, i);
    }
    std::cout << last << '\n';
    return 0;
}
    

The C++ binary gave the answer in under 0.5 seconds. This solution uses a flat vector, indexed by the number itself, to track the turn on which each number was last spoken, which we already expect to be faster than a map; I tried an unordered_map as well and got a runtime of around 3 seconds, still much faster than the Haskell binary. Using a vector is a little wasteful in space, but a 30-million-element vector isn't ridiculous, even if it does end up being very sparsely populated.

Since the algorithm is identical, I hypothesized that the difference lies in the data structure used. Haskell data structures are immutable by default and, in general, perform very well, but in a comparison like this, where one (very large) data structure is updated in a tight loop, it is hard for the immutable versions to compete with their mutable counterparts: every Map.insert allocates fresh nodes along the path to the updated key, while the vector solution performs a single in-place write per turn.
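
To make the immutable half of that comparison concrete, here is a small GHCi illustration (using the Data.IntMap.Lazy module imported as Map in the part 1 code, with made-up values): inserting never changes the original map, it produces a new one.

    
ghci> let m0 = Map.fromList [(1, 10), (2, 20)]
ghci> let m1 = Map.insert 1 99 m0
ghci> Map.lookup 1 m0
Just 10
ghci> Map.lookup 1 m1
Just 99
    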

Haskell does have mutable data structures, but the language forces you to draw a clear line whenever you transition into code that uses them. We turn to MVector (the mutable vector type from the vector package) to try to iterate more efficiently.

    
import D15input
-- input :: [Int]
import Control.Monad
import Control.Monad.ST
import qualified Data.List as List
import qualified Data.Vector.Unboxed.Mutable as MVector

-- Allocate a vector with one slot per possible spoken number, initialised
-- to -1 ("never seen"), then record the turn of each starting number.
initVector :: Int -> [(Int, Int)] -> ST s (MVector.MVector s Int)
initVector n init = do
    result <- MVector.replicate n (-1)
    forM_ init $ \(k, v) -> MVector.write result k v
    return result

-- 'last' is the number spoken on turn 'pos'; keep reading and writing the
-- vector in place until pos reaches n.
run :: Int -> MVector.MVector s Int -> Int -> Int -> ST s Int
run n v last pos
    | pos == n = pure last
    | otherwise = do
        prevPos <- MVector.read v last
        let diff = if prevPos == (-1) then 0 else pos - prevPos
        MVector.write v last pos
        run n v diff (pos + 1)

-- runST seals the mutation inside and hands back a plain Int.
getNth :: Int -> [(Int, Int)] -> Int -> Int -> Int
getNth n init last pos =
    runST $ do
        v <- initVector n init
        run n v last pos

solve :: Int
solve = getNth 30000000 (List.zip input [1..]) 0 (List.length input + 1)

main :: IO ()
main = print solve
    

A quick disclaimer: I am quite inexperienced with writing in this style, so there will surely be various places where the logic can be more clearly expressed. The algorithm is the same, however, and I tried to keep the parallels with the immutable map code as strong as possible.

First, the results: compiled with -O2, the mutable version takes somewhere between 0.6 and 0.7 seconds to solve part 2! Not quite as fast as the C++ version, but we are at least in the same order of magnitude now.

When working with MVector, we need to perform our computations inside a monad that allows mutation; here that is the ST monad. We first create an initVector function that constructs our initial MVector from the input. This function reads very much like an imperative one, looping over the initial input and writing each starting number's turn into the MVector. Note that it returns not an MVector s Int, but rather an ST s (MVector s Int), effectively signaling that this function is meant to be combined with other actions inside the ST monad.

The run function is the ST-monad version of the part 1 solution. If we have reached our target, we can return the last element (wrapped using pure, to convert its type from Int to ST s Int). Otherwise, we read the last position, compute the next element, update the vector, and recurse. Again, this function returns an ST s Int, rather than an Int.

These two functions are put together in getNth, which somehow manages to return a plain Int, even though it is composed of functions that return ST s-wrapped values. getNth builds its own ST action via do notation, one that simply calls initVector and then run. That action has type ST s Int, and the runST function unwraps it to obtain the actual Int result.

All of this took quite a while for me to understand when I was first reading up about it, and I'm still not sure I fully understand what is going on behind the scenes. The mental model that has served me well so far is that the ST monad is used to contain "impure" operations such as mutating state. While working within this monad, we can perform actions that modify our MVector. For example, the write call has return type m (), here ST s (), which is Haskell's way of saying there is no meaningful return value; that makes sense for a function we call only for its "side-effect". Once we have assembled an ST s Int, we can run it through the special runST function to extract the final result; runST's type, (forall s. ST s a) -> a, is what guarantees that nothing mutable can leak out of the computation.
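
To make that concrete, here is a minimal, self-contained sketch of the same pattern, separate from the solution and with invented names and values: allocate a tiny mutable vector inside ST, write to it for the side effect, read from it, and let runST hand the result back as a plain Int.

    
import Control.Monad.ST (runST)
import qualified Data.Vector.Unboxed.Mutable as MVector

-- A toy ST computation: each write has type ST s (), each read ST s Int.
sumTwoSlots :: Int
sumTwoSlots = runST $ do
    v <- MVector.replicate 3 (0 :: Int)  -- ST s (MVector s Int)
    MVector.write v 0 7                  -- ST s (): called for the side effect
    MVector.write v 1 8
    a <- MVector.read v 0                -- ST s Int
    b <- MVector.read v 1
    pure (a + b)                         -- the whole do-block is an ST s Int

main :: IO ()
main = print sumTwoSlots                 -- runST has already extracted 15
    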

