Advent of Code 2020 in Haskell - Day 18





Part 1

Day 18 asks us to evaluate arithmetic expressions. I was excited to see this problem, since building expression trees and parsing strings into them is exactly the kind of task Haskell is perfect for.

First, let us define the expression tree structure, as well as how to evaluate it.

    
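import D18input   -- puzzle input; provides input :: [String]
import Data.Char  -- for digitToInt

-- An expression is a single digit, or the sum or product of two expressions.
data Expr = Digit Char | Add Expr Expr | Mult Expr Expr deriving Show

-- Evaluate recursively: digits are leaves, Add/Mult combine their subtrees.
eval :: Expr -> Int
eval (Digit c)    = digitToInt c
eval (Add e1 e2)  = eval e1 + eval e2
eval (Mult e1 e2) = eval e1 * eval e2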
    

Our expression tree data type is defined recursively. An expression is either a single digit, or it is the sum of two expressions, or it is the product of two expressions. For those seeing this sort of definition for the first time, it can take a while to wrap your mind around it, but you can convince yourself that any arithmetic expression in the problem can be expressed in this way.

The evaluation of this tree can be written recursively in the same way. If we are evaluating a digit, we just convert it to its numeric value. If we are evaluating an addition, we evaluate the left and right subexpressions and add the results. If we are evaluating a multiplication, we again evaluate the subexpressions and then multiply. At this point, we can evaluate arbitrary expressions, assuming we can build the Expr tree. For example, consider the expression 1 + (2 * (3 + 4)) = 15:

    
*Main> eval $ Add (Digit '1') (Mult (Digit '2') (Add (Digit '3') (Digit '4')))
15        
    
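The tree itself fixes the order of operations, and both parts of the puzzle depend on that. As a small illustration, here are the two trees the string 1 + 2 * 3 can stand for. The first is the strict left-to-right reading that part 1 requires, and the second is the conventional reading where multiplication binds tighter. This is a sketch, and leftToRight and usualPrecedence are hypothetical names that are not part of the solution.

```haskell
import Data.Char (digitToInt)

-- Same tree type and evaluator as above.
data Expr = Digit Char | Add Expr Expr | Mult Expr Expr deriving Show

eval :: Expr -> Int
eval (Digit c)    = digitToInt c
eval (Add e1 e2)  = eval e1 + eval e2
eval (Mult e1 e2) = eval e1 * eval e2

-- Hypothetical example: "1 + 2 * 3" read strictly left to right, (1 + 2) * 3
leftToRight :: Expr
leftToRight = Mult (Add (Digit '1') (Digit '2')) (Digit '3')

-- Hypothetical example: "1 + 2 * 3" with the usual precedence, 1 + (2 * 3)
usualPrecedence :: Expr
usualPrecedence = Add (Digit '1') (Mult (Digit '2') (Digit '3'))

main :: IO ()
main = do
    print (eval leftToRight)      -- 9
    print (eval usualPrecedence)  -- 7
```

Only the tree differs between the two readings. eval is the same in both cases, which is why parsing is where all the interesting work happens.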

Now we need to actually construct the Expr from the string representing the expression. We will do some light preprocessing to make our job easier.

    
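-- Drop spaces, reverse, then swap '(' and ')' so the reversed string
-- still reads like a normal bracketed expression.
preprocess :: String -> String
preprocess = map flipBrackets . reverse . filter (/= ' ')
    where flipBrackets '(' = ')'
          flipBrackets ')' = '('
          flipBrackets c   = c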
    

Our job is simplified by the fact that every number in the input is a single digit, so we can treat each character as its own token. If we see a '3', we know it represents the value 3, and we never need to look ahead to check whether it is followed by, say, a '5' that would make it part of a larger number. Because every token is a single character, the spaces carry no information, so the first step is to filter them out.

Next, we reverse the string. The problem specifies that operators are evaluated left to right (they are left-associative), but the parsing strategy outlined below naturally produces a right-associative tree. Both operations here (addition and multiplication) are commutative, so parsing the reversed string right-associatively gives the same result as parsing the original left-associatively. For example, 1 + 2 * 3 should be (1 + 2) * 3 = 9. Reversed, it becomes 3 * 2 + 1, which a right-associative parse reads as 3 * (2 + 1) = 9.

The final step is to flip the brackets after reversing, so that our expression still looks like a standard mathematical expression (otherwise "(3 + 5)" becomes ")5+3("). This is not strictly necessary, because we could instead flip our interpretation of the brackets. However, the flip is cheap, and it makes the rest of our code (and any debugging output) easier to read.
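The following self-contained sketch checks the preprocessing and the reversal argument on concrete strings. preprocess is the part 1 function from above. apply, foldLeftToRight, and foldRightAssoc are hypothetical helpers that evaluate a bracket-free string strictly left to right and strictly right to left, respectively.

```haskell
import Data.Char (digitToInt)

-- Part 1 preprocessing, as above.
preprocess :: String -> String
preprocess = map flipBrackets . reverse . filter (/= ' ')
    where flipBrackets '(' = ')'
          flipBrackets ')' = '('
          flipBrackets c   = c

-- Hypothetical helper: turn an operator character into a function.
apply :: Char -> Int -> Int -> Int
apply '+' = (+)
apply '*' = (*)
apply op  = error ("unknown operator " ++ [op])

-- Hypothetical helper: evaluate a bracket-free string such as "1+2*3"
-- strictly left to right, i.e. as ((1+2)*3).
foldLeftToRight :: String -> Int
foldLeftToRight []       = error "empty expression"
foldLeftToRight (d:rest) = go (digitToInt d) rest
    where go acc (op:x:xs) = go (apply op acc (digitToInt x)) xs
          go acc _         = acc

-- Hypothetical helper: evaluate a bracket-free string such as "3*2+1"
-- strictly right to left, i.e. as (3*(2+1)).
foldRightAssoc :: String -> Int
foldRightAssoc [d]         = digitToInt d
foldRightAssoc (d:op:rest) = apply op (digitToInt d) (foldRightAssoc rest)
foldRightAssoc _           = error "malformed expression"

main :: IO ()
main = do
    putStrLn (preprocess "(3 + 5) * 2")                 -- 2*(5+3)
    print (foldLeftToRight "1+2*3+4")                    -- 13
    print (foldRightAssoc (preprocess "1 + 2 * 3 + 4"))  -- 13
```

The trick relies on commutativity. With subtraction, 8-2-1 is 5 when read left to right, but the reversed string 1-2-8 read right-associatively is 1-(2-8) = 7.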

    
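-- Return the text up to (not including) the ')' that closes an already
-- opened '('. depth counts how many brackets are still open.
findMatching :: String -> Int -> String
findMatching ('(':cs) depth = '(' : findMatching cs (depth + 1)
findMatching (')':cs) depth
    | depth == 1 = ""                                 -- found the match
    | otherwise  = ')' : findMatching cs (depth - 1)
findMatching (c:cs) depth = c : findMatching cs depth
findMatching [] _ = error "findMatching: unbalanced brackets"

-- The first term: a single digit, or the contents of a bracketed group.
headTerm :: String -> String
headTerm (c:cs)
    | c == '('  = findMatching cs 1
    | otherwise = [c]

-- Split off the first term (without its outer brackets) from the rest.
stripTerm :: String -> (String, String)
stripTerm s = let t = headTerm s
                  l = if head s == '(' then length t + 2 else 1
              in (t, drop l s)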
    

With preprocessing done, our next step is to write stripTerm. This function takes an expression string and separates out the first term, returning both the first term and the remainder of the expression. If the first term is a bracketed expression, stripTerm takes the entire bracketed expression (since that is the first term), strips off the outer brackets, and returns the contents as the first term.

When the first term is a plain digit, the job of stripTerm is easy. When the first character is a '(' instead, we need to find the matching ')' to know what the entire first term is. This is done by the findMatching function.

findMatching is an adaptation of the classic bracket matching problem. The naive approach of scanning forward until we see a ')' does not work, because we may pass additional '(' along the way. In that case, the first ')' closes one of those, not the one we want to match. findMatching handles this by tracking an extra Int that counts how many brackets are currently open. As we move along the string, each '(' increases the count and each ')' decreases it. The match is the ')' that would bring the count to zero.
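To see the open-bracket counter at work, here is a self-contained copy of the three helpers with a few sample calls. In this copy the counter is called depth. The last two calls show that nested brackets are skipped over correctly. This is a sketch that mirrors the functions above, with an explicit error for unbalanced input added as an assumption.

```haskell
-- Text up to the ')' that closes an already opened '('.
-- depth is the number of brackets still open.
findMatching :: String -> Int -> String
findMatching ('(':cs) depth = '(' : findMatching cs (depth + 1)
findMatching (')':cs) depth
    | depth == 1 = ""
    | otherwise  = ')' : findMatching cs (depth - 1)
findMatching (c:cs) depth = c : findMatching cs depth
findMatching [] _ = error "unbalanced brackets"

-- First term: a digit, or the contents of a bracketed group.
headTerm :: String -> String
headTerm ('(':cs) = findMatching cs 1
headTerm (c:_)    = [c]
headTerm []       = error "empty expression"

-- Split the first term (without outer brackets) from the remainder.
stripTerm :: String -> (String, String)
stripTerm s = let t = headTerm s
                  l = if head s == '(' then length t + 2 else 1
              in (t, drop l s)

main :: IO ()
main = do
    print (stripTerm "3*4")          -- ("3","*4")
    print (stripTerm "(1+(2*3))+4")  -- ("1+(2*3)","+4")
    print (stripTerm "((5))")        -- ("(5)","")
```

In the second call, the ')' after the 3 only brings the count from 2 back to 1, so scanning continues to the next ')'.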

    
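-- Build an Expr tree from a preprocessed string. Recursing on the whole
-- remainder makes the tree right-associative, which is why we reversed.
parse :: String -> Expr
parse [c] = Digit c                          -- base case: a lone digit
parse s = case stripTerm s of
    (first, "")   -> parse first             -- e.g. "(3+5)" becomes "3+5"
    (first, rest) -> let op     = head rest
                         second = tail rest
                     in case op of
                          '+' -> Add  (parse first) (parse second)
                          '*' -> Mult (parse first) (parse second)
                          _   -> error ("parse: unexpected operator " ++ [op])

-- input :: [String] (one expression per line) comes from D18input.

solve :: Int
solve = sum $ map (eval . parse . preprocess) input

main :: IO ()
main = print solve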
    

Finally, we can write the parse function that builds our Expr tree. In the base case, the string is a single digit, so the tree is a single Digit. Otherwise, each parse step starts by using stripTerm to get the first term of the expression. If that term was the entire expression (i.e., the remainder is ""), we simply parse the term again. This is the case that handles strings like "(3+5)": the first term is the entire "(3+5)", and parse is called again on "3+5". In the general case, where there is a remainder, it is guaranteed to start with either '+' or '*' followed by a complete second expression. We split it into op and second, build this level of the Expr tree, and use parse to create the sub-Exprs from the expression substrings. Because parse recurses on the whole remainder second, a chain such as 3*2+1 becomes Mult (Digit '3') (Add (Digit '2') (Digit '1')), which is the right-associative parse mentioned earlier.

The actual solve function is now very simple. For each string in the input, perform the preprocess step, parse it into an expression tree, and then evaluate it. Then, sum all of the results.
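The real input lives in the author's D18input module, so this self-contained sketch runs the whole part 1 pipeline on a few hypothetical sample expressions instead. The list samples and the function part1 are made-up names. parse is lightly restructured to pattern-match the operator rather than use head and tail, which does not change its behavior. The expected values were worked out by hand with strict left-to-right evaluation.

```haskell
import Data.Char (digitToInt)

data Expr = Digit Char | Add Expr Expr | Mult Expr Expr deriving Show

eval :: Expr -> Int
eval (Digit c)    = digitToInt c
eval (Add e1 e2)  = eval e1 + eval e2
eval (Mult e1 e2) = eval e1 * eval e2

-- Part 1 preprocessing: drop spaces, reverse, swap brackets.
preprocess :: String -> String
preprocess = map flipBrackets . reverse . filter (/= ' ')
    where flipBrackets '(' = ')'
          flipBrackets ')' = '('
          flipBrackets c   = c

findMatching :: String -> Int -> String
findMatching ('(':cs) depth = '(' : findMatching cs (depth + 1)
findMatching (')':cs) depth
    | depth == 1 = ""
    | otherwise  = ')' : findMatching cs (depth - 1)
findMatching (c:cs) depth = c : findMatching cs depth
findMatching [] _ = error "unbalanced brackets"

headTerm :: String -> String
headTerm ('(':cs) = findMatching cs 1
headTerm (c:_)    = [c]
headTerm []       = error "empty expression"

stripTerm :: String -> (String, String)
stripTerm s = let t = headTerm s
                  l = if head s == '(' then length t + 2 else 1
              in (t, drop l s)

parse :: String -> Expr
parse [c] = Digit c
parse s = case stripTerm s of
    (first, "")        -> parse first
    (first, op:second) -> case op of
        '+' -> Add  (parse first) (parse second)
        '*' -> Mult (parse first) (parse second)
        _   -> error ("unexpected operator " ++ [op])

-- Hypothetical stand-in for the puzzle input.
samples :: [String]
samples = [ "1 + 2 * 3 + 4 * 5 + 6"
          , "2 * 3 + (4 * 5)"
          , "5 + (8 * 3 + 9 + 3 * 4 * 3)" ]

part1 :: String -> Int
part1 = eval . parse . preprocess

main :: IO ()
main = do
    print (map part1 samples)        -- [71,26,437]
    print (sum (map part1 samples))  -- 534
```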

Part 2

Part 2 asks us to deal with operator precedence. Instead of addition and multiplication having the same precedence, we need to be sure we perform the additions first. This only changes how we build our Expr tree - once we have a tree, we evaluate it in exactly the same way.

Since we put so much effort into writing the parser and evaluator in part 1, it would be nice to reuse them. Rather than modify our parse function to consider operator precedence, we instead modify the input string, inserting brackets so that multiplications happen after additions. If we replace every '*' with the string ")*(", we are effectively saying "evaluate the left and right sides of a multiplication before performing the multiplication itself." To keep the expression balanced, we also need to double every bracket and wrap the entire expression in one more set of brackets.

Interestingly, I learned during this challenge that some early compilers took this same approach. A few example transformations (shown on the original strings, before the reversal our code performs):

    
1+2*3+4 -> (1+2)*(3+4)
5*6*7*8 -> (5)*(6)*(7)*(8)
(5*6)+7*4+3 -> (((5)*(6))+7)*(4+3)
    
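Here is a sketch that checks the example transformations above by applying the bracket insertion on its own, without the reversal step. The name addPrecedenceBrackets is hypothetical. Because nothing is reversed here, '(' maps to "((" and ')' maps to "))".

```haskell
-- Hypothetical helper: fully bracket an expression so that additions group
-- before multiplications. It wraps the whole string, doubles every bracket,
-- and turns each '*' into ")*(".
addPrecedenceBrackets :: String -> String
addPrecedenceBrackets s = "(" ++ concatMap expand (filter (/= ' ') s) ++ ")"
    where expand '(' = "(("
          expand ')' = "))"
          expand '*' = ")*("
          expand c   = [c]

main :: IO ()
main = mapM_ (putStrLn . addPrecedenceBrackets)
    ["1+2*3+4", "5*6*7*8", "(5*6)+7*4+3"]
```

This prints (1+2)*(3+4), (5)*(6)*(7)*(8), and (((5)*(6))+7)*(4+3), matching the table.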

In other words, for part 2, all we need to do is swap out the preprocess step, implementing this replacement. After that, everything else we implemented just works!

    
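-- Part 2: clean up and reverse as in part 1, then insert brackets so every
-- addition is grouped before any multiplication. Because the string has
-- been reversed, '(' becomes a doubled ')' and vice versa.
preprocess :: String -> String
preprocess s = "(" ++ concatMap bracketPrecedence (reverse (filter (/= ' ') s)) ++ ")"
    where bracketPrecedence '(' = "))"
          bracketPrecedence ')' = "(("
          bracketPrecedence '*' = ")*("
          bracketPrecedence c   = [c]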
    
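To confirm that swapping only preprocess is enough, this self-contained sketch runs the part 2 pipeline on a few hypothetical sample expressions. The list samples and the function part2 are made-up names. Everything except preprocess is the part 1 code, with parse lightly restructured to pattern-match the operator. The expected values were computed by hand with addition taking precedence over multiplication.

```haskell
import Data.Char (digitToInt)

data Expr = Digit Char | Add Expr Expr | Mult Expr Expr deriving Show

eval :: Expr -> Int
eval (Digit c)    = digitToInt c
eval (Add e1 e2)  = eval e1 + eval e2
eval (Mult e1 e2) = eval e1 * eval e2

-- Part 2 preprocessing: reverse, then bracket so additions group first.
preprocess :: String -> String
preprocess s = "(" ++ concatMap bracketPrecedence (reverse (filter (/= ' ') s)) ++ ")"
    where bracketPrecedence '(' = "))"
          bracketPrecedence ')' = "(("
          bracketPrecedence '*' = ")*("
          bracketPrecedence c   = [c]

findMatching :: String -> Int -> String
findMatching ('(':cs) depth = '(' : findMatching cs (depth + 1)
findMatching (')':cs) depth
    | depth == 1 = ""
    | otherwise  = ')' : findMatching cs (depth - 1)
findMatching (c:cs) depth = c : findMatching cs depth
findMatching [] _ = error "unbalanced brackets"

headTerm :: String -> String
headTerm ('(':cs) = findMatching cs 1
headTerm (c:_)    = [c]
headTerm []       = error "empty expression"

stripTerm :: String -> (String, String)
stripTerm s = let t = headTerm s
                  l = if head s == '(' then length t + 2 else 1
              in (t, drop l s)

parse :: String -> Expr
parse [c] = Digit c
parse s = case stripTerm s of
    (first, "")        -> parse first
    (first, op:second) -> case op of
        '+' -> Add  (parse first) (parse second)
        '*' -> Mult (parse first) (parse second)
        _   -> error ("unexpected operator " ++ [op])

-- Hypothetical stand-in for the puzzle input.
samples :: [String]
samples = [ "1 + 2 * 3 + 4 * 5 + 6"
          , "2 * 3 + (4 * 5)"
          , "5 + (8 * 3 + 9 + 3 * 4 * 3)" ]

part2 :: String -> Int
part2 = eval . parse . preprocess

main :: IO ()
main = print (map part2 samples)  -- [231,46,1445]
```

For example, the first sample becomes (6+5)*(4+3)*(2+1) after preprocessing, which is 11 * 7 * 3 = 231.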
