State monad

During some recent work I found the need to use the State monad. Unfortunately I had no idea how to do this.

In this post we’ll retrace the steps I took while trying to generate random strings with Haskell. We’ll start by coding specific, non-monadic pieces to solve the problem, then generalise this by implementing our own State monad, before finally switching over to use Haskell’s implementation. By the end of it we’ll hopefully all understand the State monad (which is more than I can say for when I first started tackling this problem :)).

If you’ve never encountered monads before then you may want to start with my introduction to monads (yes, yes, Yet Another Monad tutorial. Feel free to read a good one instead if you don’t like mine :)).

Aside: I think we can actually use QuickCheck to generate arbitrary strings for us. But where’s the fun in that?

Random characters

To get a (pseudo-)random value we can use Haskell’s System.Random module:

ghci> import System.Random
ghci> :t randomR
randomR :: (RandomGen g, Random a) => (a, a) -> g -> (a, g)

The randomR function takes as input a tuple with the lower and upper-bounds of the random value to generate, and a random number generator. Its output is a tuple consisting of a random value within those bounds, and a new generator which will generate the next random value in the sequence. (This will work for any type a that is an instance of the Random type class, and any generator g which is an instance of RandomGen.) We can use it like this:

ghci> let g = mkStdGen 42
ghci> randomR (1,10) g
(2,1720602 40692)
ghci> randomR ('a','z') g
('n',1720602 40692)

Here we’ve seeded a new generator with the number 42 and generated a number between 1 and 10 to get 2, and a character between 'a' and 'z' to get 'n'. Both calls were given the same generator g, so both hand back the same new generator. So we’ve got random character generation under control.

Multiple random characters

To get a new random value we’re going to have to use randomR again, but this time with the new generator from the tuple returned by our previous call. (Remember Haskell functions are pure; their output depends purely on their input. If we pass the same generator mkStdGen 42 as input we’ll get exactly the same “random” character as output.) We’ll have to keep feeding each generator that comes out of one call as input to the next call:

ghci> let (a', g') = randomR ('a','z') g
ghci> let (a'',g'') = randomR ('a','z') g'
ghci> let (a''',g''') = randomR ('a','z') g''
ghci> let (a'''',g'''') = randomR ('a','z') g'''
ghci> let (a''''',g''''') = randomR ('a','z') g''''
ghci> [a', a'', a''', a'''', a''''']
"ndfeo"

To generate a string of some length i, we’ll need to call this randomR ('a','z') function i times, accumulating both the string value and the generator at each step. Let’s hack this out:

import System.Random

getRandomChar :: RandomGen g => g -> (Char, g)
getRandomChar = randomR ('a','z')

getRandomStringWithLength :: RandomGen g => Int -> g -> (String, g)
getRandomStringWithLength i g =
    let charGenerators = replicate i getRandomChar
    in  foldr apply ([], g) charGenerators

apply :: (a -> (b, a)) -> ([b], a) -> ([b], a)
apply f (xs, g) = let (x, g') = f g
                  in  (x:xs, g')

Here we’ve used the name getRandomChar to refer to the randomR ('a','z') function which will give us a single character plus a new generator.

We’ve also got a getRandomStringWithLength function which takes an integer (how many characters long the string should be) and a generator, and outputs the resulting string and the last generator produced. It calls replicate i getRandomChar, which will output a list of getRandomChar functions. If we call replicate 3 getRandomChar, we’ll get [getRandomChar, getRandomChar, getRandomChar]. We need to reduce this list, accumulating characters and the most recent generator as we go.

For that getRandomStringWithLength uses a fold. The fold’s initial value (seed) is a tuple of an empty list and the generator passed in, ([], g), and from there we apply each of the charGenerators.

The apply function is fairly hideous. It takes the current getRandomChar function and the (string, generator) value accumulated in the fold, applies the function to that generator, and outputs a tuple of (newString, newGenerator). But, hideousness aside, we now have a function that will generate a random string of characters:

ghci> getRandomStringWithLength 5 (mkStdGen 42)
("oefdn",1421974012 652912057)
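
Notice that "oefdn" is "ndfeo" backwards. A minimal sketch of why, assuming a toy counter in place of the random generator (tick is a hypothetical stand-in for getRandomChar): foldr runs the rightmost function first, and apply conses each new value onto the front of the list, so the value produced first ends up last.

```haskell
-- Toy "generator": the state is a counter. Each call returns the
-- current count along with the incremented counter.
-- (tick is a hypothetical stand-in for getRandomChar.)
tick :: Int -> (Int, Int)
tick n = (n, n + 1)

-- Same apply as in the post.
apply :: (a -> (b, a)) -> ([b], a) -> ([b], a)
apply f (xs, g) = let (x, g') = f g
                  in  (x:xs, g')

-- Thread the counter through three ticks.
threaded :: ([Int], Int)
threaded = foldr apply ([], 0) (replicate 3 tick)
-- threaded == ([2,1,0], 3): the value produced first ends up last.

-- Reversing the accumulated list restores generation order.
inOrder :: ([Int], Int)
inOrder = let (xs, n) = threaded in (reverse xs, n)
```

For random strings the order makes no difference, which is why the post can ignore it.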

Encapsulating state

So let’s tidy up this mess I’ve made. The getRandomChar function takes some initial generator, and outputs a tuple with a character and a new generator. A more general way to think of this is as a function s -> (a,s), so that for some initial state s, we want to output a value a and a new state of s. We’ll wrap this computation up in a new data type, representing a Stateful computation:

newtype Stateful s a = Stateful { runState :: s -> (a,s) }

Now we can modify getRandomChar to output a Stateful computation, and we can run a state through that computation using runState:

randomChar :: (RandomGen g) => Stateful g Char
randomChar = Stateful $ randomR ('a','z')

-- ghci> runState randomChar (mkStdGen 42)
-- ('n',1720602 40692)
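
Nothing in Stateful is specific to random generators. As a rough illustration, here is a sketch where the state is a stack of numbers instead (pop and push are hypothetical names made up for this example):

```haskell
newtype Stateful s a = Stateful { runState :: s -> (a, s) }

-- Take the top element off the stack, if there is one.
pop :: Stateful [Int] (Maybe Int)
pop = Stateful $ \stack -> case stack of
    []     -> (Nothing, [])
    (x:xs) -> (Just x, xs)

-- Put an element on top of the stack; there is no interesting result,
-- so the value part is just ().
push :: Int -> Stateful [Int] ()
push x = Stateful $ \stack -> ((), x : stack)

-- ghci> runState pop [1,2,3]
-- (Just 1,[2,3])
-- ghci> runState (push 0) [1]
-- ((),[0,1])
```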

Combining stateful computations

Now what we want to do is to chain together functions which work on states, so that the state output from one function gets fed into the input of the next function (which we saw in Multiple random characters). In my take on monads I concluded “Monads let us apply functions inside a context which depend on previous results”. In this case our next random number depends on the random generator produced by the previous call. So let’s wire up a monad instance based on our Stateful type, so that when we combine two stateful computations, the new state produced by one gets passed as the input to the next. To do this we’ll implement two functions. The first is the bind function >>=:

-- For a monad m:
(>>=) :: m a -> (a -> m b) -> m b

-- For (Stateful s), substitute for m:
(>>=) :: Stateful s a -> (a -> Stateful s b) -> Stateful s b

By substituting our monad type Stateful s into the type signature, we can see bind needs to take a stateful computation which produces an a, and a function which takes an a and outputs a new stateful computation Stateful s b. The entire bind function ultimately outputs a Stateful s b, a stateful computation that combines the result of the initial Stateful s a and the Stateful s b produced from the a -> Stateful s b function, feeding through the correct state at each step.

The second function we need is return, which will take some a and return a new computation that always returns that a:

return :: a -> m a
-- Substituting in for (Stateful s):
return :: a -> Stateful s a

Implementing our state monad

Here’s an implementation that satisfies these type signatures:

instance Monad (Stateful s) where
    -- (>>=) :: Stateful s a -> (a -> Stateful s b) -> Stateful s b
    statefulA >>= f
        = Stateful $ \s ->
                let (a, s')   = runState statefulA s
                    statefulB = f a
                    (b, s'')  = runState statefulB s'
                in (b, s'')

    -- return :: a -> Stateful s a
    return a = Stateful $ \s -> (a, s)

The first argument is a stateful computation that produces an a. Let’s call this statefulA. The second argument is a function f that given an a will produce a stateful computation that will give us a b. The final output will need to be a new Stateful s b, which is a type that holds a function s -> (b, s). So we’ll start by creating a value of this type, Stateful $ \s -> ....

The function in this new stateful computation, when given a state s, will run statefulA to produce a tuple containing an a and new state s'. If we feed this a into the function f we’ll get a new stateful computation that produces a b, which we’ll call statefulB.

We can’t just output statefulB though. Its type is statefulB :: Stateful s b, but we’re in the middle of implementing a function s -> (b, s) that we’re already wrapping in a Stateful context. We don’t want to end up with a Stateful s (Stateful s b)!

So instead we output the result of calling runState statefulB s', which gets a tuple containing a b and the next state s'' out of statefulB by passing through the state s' from the previous call. In this way we ensure that the state transformation from s to s' to s'' is passed through the resulting stateful computation, which was our aim from the beginning.
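
To watch the state move from s to s' to s'', here is a small sketch that writes the same logic as an ordinary function, so it can be tried without any type class instances. The names bindStateful, tick and twoTicks are made up for this example, and a plain counter stands in for the random generator.

```haskell
newtype Stateful s a = Stateful { runState :: s -> (a, s) }

-- The same logic as (>>=), written as a plain function
-- (bindStateful is a hypothetical name).
bindStateful :: Stateful s a -> (a -> Stateful s b) -> Stateful s b
bindStateful statefulA f = Stateful $ \s ->
    let (a, s') = runState statefulA s
    in  runState (f a) s'

-- Toy computation: return the counter and increment it.
tick :: Stateful Int Int
tick = Stateful $ \n -> (n, n + 1)

-- Run tick twice and pair up both results.
twoTicks :: Stateful Int (Int, Int)
twoTicks =
    tick `bindStateful` \a ->
    tick `bindStateful` \b ->
    Stateful $ \s -> ((a, b), s)

-- ghci> runState twoTicks 10
-- ((10,11),12)
```

The second tick sees 11 rather than 10, which shows the state produced by the first computation really is the one fed into the second.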

That’s the bulk of the work done, but for completeness let’s go look at the return function. It takes some value a and returns a Stateful computation that, given a state s, will always produce that value a and the same state s. In other words, it takes an ordinary value and puts it into the context of a stateful computation that takes a state and returns (value, state).

Aside: We should also make Stateful an instance of Functor and Applicative (all monads are also functors and applicative functors), and since GHC 7.10 the compiler insists on both before it will accept a Monad instance. I’ll omit that step as this post is long enough already. If you’d like me to go through it let me know. We should also make sure that our monad implementation satisfies the Monad Laws, but again, long post, so let’s press on.
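
For anyone on a recent GHC who wants the code to compile, here is one way the omitted instances could look. Treat it as a sketch rather than the definitive version: tick, k and h are hypothetical helpers, and the three Bool values only spot-check the monad laws from a single starting state, which is evidence and not a proof.

```haskell
newtype Stateful s a = Stateful { runState :: s -> (a, s) }

instance Functor (Stateful s) where
    -- Transform the result, leave the state transition alone.
    fmap f sa = Stateful $ \s ->
        let (a, s') = runState sa s
        in  (f a, s')

instance Applicative (Stateful s) where
    pure a = Stateful $ \s -> (a, s)
    -- Run the computation producing the function first, then the argument.
    sf <*> sa = Stateful $ \s ->
        let (f, s')  = runState sf s
            (a, s'') = runState sa s'
        in  (f a, s'')

instance Monad (Stateful s) where
    -- return defaults to pure on modern GHC.
    sa >>= f = Stateful $ \s ->
        let (a, s') = runState sa s
        in  runState (f a) s'

-- Toy computations over a counter state.
tick :: Stateful Int Int
tick = Stateful $ \n -> (n, n + 1)

k, h :: Int -> Stateful Int Int
k x = Stateful $ \n -> (x + n, n + 1)
h x = Stateful $ \n -> (x * 2, n + x)

-- Spot-check the monad laws by running both sides from state 0.
leftIdentity, rightIdentity, associativity :: Bool
leftIdentity  = runState (return 5 >>= k) 0 == runState (k 5) 0
rightIdentity = runState (tick >>= return) 0 == runState tick 0
associativity = runState ((tick >>= k) >>= h) 0
             == runState (tick >>= (\x -> k x >>= h)) 0
```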

Using our shiny new monad instance to generate strings

So we have a randomChar function which produces a Stateful g Char computation, and we have a shiny new monad instance that lets us combine computations. We now have everything we need to run a stateful computation that will produce a random string, or a Stateful g [Char].

Let’s take stock of what we have so far:

getRandomChar :: RandomGen g => g -> (Char, g)
getRandomChar = randomR ('a','z')

getRandomStringWithLength :: RandomGen g => Int -> g -> (String, g)
getRandomStringWithLength i g =
    let charGenerators = replicate i getRandomChar
    in  foldr apply ([], g) charGenerators

-- snip --
newtype Stateful s a = Stateful { runState :: s -> (a,s) }
instance Monad (Stateful s) where
    statefulA >>= f = ...
    return a = ...
-- snip --

randomChar :: (RandomGen g) => Stateful g Char
randomChar = Stateful $ randomR ('a','z')

randomStringWithLength :: (RandomGen g) => Int -> Stateful g String
randomStringWithLength i = undefined -- ???

Our original getRandomStringWithLength function called replicate i getRandomChar to get a [g -> (Char, g)]. If we use our new function, replicate i randomChar, we get [Stateful g Char]. We’d like a way to convert a [Stateful g Char] to a Stateful g [Char].

Haskell has a built-in function called sequence which does just this. Specialised to lists, it has the signature [m a] -> m [a], for any monad m. Substituting in for our Stateful s monad, when producing a Char, this gives us [Stateful s Char] -> Stateful s [Char]. Now [Char] is the same as String, which means we can generate a stateful computation for a random string from randomChar like this:

randomStringWithLength :: (RandomGen g) => Int -> Stateful g String
randomStringWithLength i = sequence (replicate i randomChar)
                   -- or = (sequence . replicate i) randomChar
                   -- or = replicateM i randomChar

-- ghci> runState (randomStringWithLength 10) (mkStdGen 42)
-- ("ndfeolyrgn",2052659270 1336516156)

I assume this pattern of replicate and sequence is fairly common, as Haskell’s Control.Monad module provides a replicateM function that does just this.
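
There’s no magic in sequence either; it can be built from >>= and return alone. Here is a rough sketch for lists. The names sequenceList and replicateList are made up so they don’t clash with the Prelude, and this is not how the library actually defines them.

```haskell
-- Run each computation in turn, collecting the results in order.
-- (sequenceList is a hypothetical name, not the library definition.)
sequenceList :: Monad m => [m a] -> m [a]
sequenceList []     = return []
sequenceList (m:ms) =
    m               >>= \x  ->
    sequenceList ms >>= \xs ->
    return (x : xs)

-- replicate followed by sequence, in the spirit of replicateM.
replicateList :: Monad m => Int -> m a -> m [a]
replicateList i = sequenceList . replicate i

-- Works for any monad, for example Maybe:
-- ghci> sequenceList [Just 1, Just 2, Just 3]
-- Just [1,2,3]
-- ghci> sequenceList [Just 1, Nothing]
-- Nothing
```

Because it is written against the Monad interface only, the same function threads state through a list of Stateful computations without knowing anything about state.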

Now with 12% extra randomness!

Let’s randomise the string length:

randomString :: (RandomGen g) => Stateful g String
randomString = Stateful $ \g ->
    let (charsToGen, g') = randomR (1,10) g
    in  runState (randomStringWithLength charsToGen) g'

This works, but we’re back to explicitly passing around the state of the random number generator. That’s no good. Instead we can create a stateful computation for getting a random number between 1 and 10 using Stateful (randomR (1,10)), and bind that to our randomStringWithLength function. This will do the work of passing the random number and new state into our randomStringWithLength function:

randomString :: (RandomGen g) => Stateful g String
randomString = Stateful (randomR (1,10)) >>= (\i -> randomStringWithLength i)
       -- or just:
       --   = Stateful (randomR (1,10)) >>= randomStringWithLength

-- ghci> runState randomString (mkStdGen 1234)
-- ("xxetyzjjrk",995854900 498340277)
-- ghci> runState randomString (mkStdGen 4321)
-- ("jbetvvma",876562129 1422611300)
-- ghci> runState randomString (mkStdGen 1423)
-- ("ikavsm",962541241 535353314)
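
The same bind can also be spelled with do notation, where each <- is a use of >>=. The sketch below is self-contained under two assumptions: it includes Functor and Applicative instances so it compiles on a recent GHC, and it swaps System.Random for a hypothetical toy generator (toyRandomR, a small linear congruential generator over an Int seed) so that it needs nothing beyond base. The toy generator is for illustration only and is not statistically sound.

```haskell
newtype Stateful s a = Stateful { runState :: s -> (a, s) }

instance Functor (Stateful s) where
    fmap f sa = Stateful $ \s -> let (a, s') = runState sa s in (f a, s')

instance Applicative (Stateful s) where
    pure a    = Stateful $ \s -> (a, s)
    sf <*> sa = Stateful $ \s ->
        let (f, s')  = runState sf s
            (a, s'') = runState sa s'
        in  (f a, s'')

instance Monad (Stateful s) where
    sa >>= f = Stateful $ \s -> let (a, s') = runState sa s in runState (f a) s'

-- Hypothetical stand-in for randomR: the "generator" is an Int seed.
toyRandomR :: (Int, Int) -> Int -> (Int, Int)
toyRandomR (lo, hi) seed =
    let seed' = (seed * 1103515245 + 12345) `mod` 2147483648
    in  (lo + seed' `mod` (hi - lo + 1), seed')

randomChar :: Stateful Int Char
randomChar = Stateful $ \s ->
    let (n, s') = toyRandomR (0, 25) s
    in  (toEnum (fromEnum 'a' + n), s')

randomStringWithLength :: Int -> Stateful Int String
randomStringWithLength i = sequence (replicate i randomChar)

-- randomString with do notation instead of an explicit >>=.
randomString :: Stateful Int String
randomString = do
    len <- Stateful (toyRandomR (1, 10))
    randomStringWithLength len
```

Running runState randomString with any Int seed gives back a lowercase string of between 1 and 10 characters plus the final seed, and the same seed always gives the same string.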

Using the built-in State monad

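Haskell’s ready-made state monad is called State (rather than Stateful) and lives in the Control.Monad.State module, which is provided by the mtl package. We can translate our example very quickly by adding an import, changing references from the type Stateful to State, and constructing new stateful computations using the state function (State is not an ordinary data constructor there, so we can’t write State $ ... the way we wrote Stateful $ ...):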

import Control.Monad.State
import System.Random

randomChar :: (RandomGen g) => State g Char
randomChar = state $ randomR ('a','z')

randomStringWithLength :: (RandomGen g) => Int -> State g String
randomStringWithLength i = replicateM i randomChar

randomString :: (RandomGen g) => State g String
randomString = state (randomR (1,10)) >>= randomStringWithLength

Parting thoughts

Until I walked through the steps I had a tough time understanding how the State monad worked. One thing that helped me to understand it was to force myself to switch between the different levels of abstraction.

First, implementing the Stateful s monad instance was a matter of mechanically substituting in types for >>= and return, then building up expressions that satisfied those types, taking care to make sure I used all the variables involved so I didn’t drop out any terms.

When it came to using that instance, I found it helped to forget how >>= works, and focus instead on what it does: it lets us create a new stateful computation based on the value produced from an existing computation. It seems to be one of those things that is very easy to over-think, and the only way I’ve found around this is to focus on what the abstraction is doing (not how it does it), and to follow the types.

Another thing I got from working through this is a bit more of a sense of how useful the monad abstraction is; we get a whole lot of useful functions for free, like sequence and replicateM. The State monad itself lets us combine these with our own functions while making sure our state gets reliably passed through the entire computation.

If you’ve found any of this hard to follow please send me an email or leave a comment; I’d be really keen to try and help out.
