By 05st
In the previous part, we set up data types to represent our syntax, and discussed how the Hindley-Milner type system is defined. We also went over polymorphism vs. monomorphism, generalization, and instantiation. In this part, we will briefly go over what substitution and unification are, and finally begin the main implementation of the type system.
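You most likely already have an idea of what a substitution is. Basically, it's a mapping from symbols to other symbols; in our case, from type variables to types. To apply a substitution means to "consistently" replace every occurrence of a symbol with what it is mapped to in the substitution. All mappings are applied simultaneously, so a type introduced by one mapping is not rewritten again by another. Examples of applying substitutions:

$$
\begin{aligned}
[a := \text{Int}]\,(a \to b) &= \text{Int} \to b \\
[a := \text{Bool},\ b := a]\,(a \to b) &= \text{Bool} \to a
\end{aligned}
$$

The substitutions appear inside the square brackets (e.g. $[a := \text{Int}]$). Multiple mappings are separated by commas. It's also common to see $[a \mapsto \text{Int}]$ or $[\text{Int}/a]$ notation used for substitutions. We will represent substitutions as a Data.Map, with type variables for the keys and types for the values.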
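We will also be using the function $\mathrm{tvs}$ to retrieve the set of all free type variables of a type. It will be implemented with Data.Set. The definition will be as follows:

$$
\begin{aligned}
\mathrm{tvs}(a) &= \{a\} \\
\mathrm{tvs}(C\ \tau_1 \cdots \tau_n) &= \mathrm{tvs}(\tau_1) \cup \cdots \cup \mathrm{tvs}(\tau_n)
\end{aligned}
$$

(Where $a$ is a type variable, $C$ is a type constructor, and $\tau_n$ is the $n$th type parameter.)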
Unification is the process by which we will build up a substitution. We don't need to know what unification is formally. All our algorithm needs to do is take a constraint (for us, always an equality constraint) and attempt to unify its two terms. For example, if both terms are type variables, e.g. $a \sim b$, then the resulting substitution can simply be $[a := b]$ or $[b := a]$. We will stick to the former, but it's mostly a matter of preference.
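We will use the notation $\tau_1 \sim \tau_2 : S$ to indicate that $\tau_1$ and $\tau_2$ are unifiable by the substitution $S$, which means $S\tau_1 = S\tau_2$. With that, the unification rules are as follows:

$$
\begin{array}{ll}
\dfrac{}{\tau \sim \tau : [\,]} & \textbf{(Same)} \\[1.5em]
\dfrac{a \notin \mathrm{tvs}(\tau)}{a \sim \tau : [a := \tau]} & \textbf{(Var Left)} \\[1.5em]
\dfrac{a \notin \mathrm{tvs}(\tau)}{\tau \sim a : [a := \tau]} & \textbf{(Var Right)} \\[1.5em]
\dfrac{\tau_1 \sim \tau_1' : S_1 \qquad S_1\tau_2 \sim S_1\tau_2' : S_2}{\tau_1 \to \tau_2 \sim \tau_1' \to \tau_2' : S_2 \circ S_1} & \textbf{(Arrow)}
\end{array}
$$

The Arrow rule extends to any type constructor $C\ \tau_1 \cdots \tau_n$: both sides must use the same constructor with the same number of parameters, and the parameters are unified left to right in the same way.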
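As a small hand-worked example of these rules (the variables $a$ and $b$ are chosen purely for illustration), here is how $a \to \text{Int} \sim \text{Bool} \to b$ is unified:

$$
\begin{aligned}
&a \sim \text{Bool} : S_1 = [a := \text{Bool}] && \text{(Var Left)} \\
&S_1\text{Int} \sim S_1 b,\ \text{i.e.}\ \text{Int} \sim b : S_2 = [b := \text{Int}] && \text{(Var Right)} \\
&a \to \text{Int} \sim \text{Bool} \to b : S_2 \circ S_1 = [a := \text{Bool},\ b := \text{Int}] && \text{(Arrow)}
\end{aligned}
$$

Applying the result to either side gives $\text{Bool} \to \text{Int}$, so the substitution does unify the two types.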
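The reason we check whether $a$ occurs in the free type variables of $\tau$ for Var Left and Var Right is that otherwise we would attempt to construct an infinite type. For example, try unifying $a \sim a \to b$. Without the check, you get the substitution $[a := a \to b]$. If we try to apply the substitution until no $a$ remains, we get the following:

$$
a \;\Rightarrow\; a \to b \;\Rightarrow\; (a \to b) \to b \;\Rightarrow\; ((a \to b) \to b) \to b \;\Rightarrow\; \cdots
$$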
To avoid this infinite loop, we perform what is called the occurs check, which is what you see written as $a \notin \mathrm{tvs}(\tau)$.
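To make the infinite loop concrete, here is a small self-contained sketch. It uses a hypothetical stand-in for the Type from part one (TVar and TCon constructors, with arrows written as TCon "->"), a one-step apply, and a hypothetical bindChecked helper that mirrors the occurs check. Applying the unchecked substitution $[a := a \to b]$ over and over never gets rid of $a$; the type just grows by two nodes each time.

```haskell
import qualified Data.Map as Map
import qualified Data.Set as Set

-- Hypothetical stand-ins for part one's types (an assumption, not the article's code).
newtype TVar = TV String deriving (Eq, Ord, Show)
data Type = TVar TVar | TCon String [Type] deriving (Eq, Show)

type Subst = Map.Map TVar Type

arrow :: Type -> Type -> Type
arrow x y = TCon "->" [x, y]

-- Free type variables, following the tvs definition.
tvs :: Type -> Set.Set TVar
tvs (TVar v) = Set.singleton v
tvs (TCon _ ts) = Set.unions (map tvs ts)

-- Apply a substitution once (every variable is replaced simultaneously).
apply :: Subst -> Type -> Type
apply s t@(TVar v) = Map.findWithDefault t v s
apply s (TCon c ts) = TCon c (map (apply s) ts)

-- Count the variables and constructors in a type.
size :: Type -> Int
size (TVar _) = 1
size (TCon _ ts) = 1 + sum (map size ts)

-- Hypothetical helper: only build [v := t] if v does not occur in t.
bindChecked :: TVar -> Type -> Either String Subst
bindChecked v t
  | v `Set.member` tvs t = Left ("Infinite type " ++ show v ++ " ~ " ++ show t)
  | otherwise = Right (Map.singleton v t)

a, b :: TVar
a = TV "a"
b = TV "b"

-- The substitution [a := a -> b], built without the occurs check.
unchecked :: Subst
unchecked = Map.singleton a (arrow (TVar a) (TVar b))

-- a, a -> b, (a -> b) -> b, ((a -> b) -> b) -> b
iterations :: [Type]
iterations = take 4 (iterate (apply unchecked) (TVar a))
```

Here map size iterations evaluates to [1, 3, 5, 7], every element still mentions a, and bindChecked a (arrow (TVar a) (TVar b)) returns a Left, which is exactly the failure the occurs check reports.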
Finally, with all of the explanation out of the way, we can begin to implement the type system. This is where it will hopefully all fit together and make sense. The algorithm we will implement is a constraint-based variant of Algorithm W: we first build up a set of constraints, and only then perform unification and substitution. The classic formulations of Algorithm W and Algorithm J instead unify and substitute while traversing the expression (Algorithm J does so by updating a single, global substitution).
To begin, let's go over the monad transformer stack we will be using for type inference. The main monad transformer will be RWST, which is equivalent to a ReaderT, WriterT, StateT stack. The monad it transforms will be Except String; you can swap in a custom error type if you wish.
The environment for the Reader will be the typing context, a Data.Map from Names to Schemes. Scheme is how we will represent type schemes (polytypes). It's just a Data.Set of type variables, together with a type in which those type variables are universally quantified.
The Writer will accumulate a list of equality constraints, each represented by another data type, Constraint. We will also define a function, constrain, which records a single constraint.
The State will keep a count (an Int) of the fresh type variables generated so far, so that we know which fresh type variable to return next. The function fresh will handle generating fresh type variables.
One last thing: enable the LambdaCase language extension, which lets us write slightly cleaner code. So far, we have:
{-# Language PatternSynonyms #-}
{-# Language LambdaCase #-}
{- ... Imports from before ... -}
import qualified Data.Map as Map
import qualified Data.Set as Set
import Control.Monad.RWS
import Control.Monad.Except
{- ... Code from before ... -}
data Scheme = Forall (Set.Set TVar) Type  -- Quantified type variables, and the type they are bound in
data Constraint = Constraint Type Type
type Context = Map.Map Name Scheme
type Count = Int
type Constraints = [Constraint]
type Infer a = RWST Context Constraints Count (Except String) a
constrain :: Type -> Type -> Infer ()
constrain t1 t2 = tell [Constraint t1 t2]  -- Record a single equality constraint
fresh :: Infer Type
fresh = do
  count <- get                   -- Number of fresh variables generated so far
  put (count + 1)
  return . TVar . TV $ show count  -- Name the new variable after the counter
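To get a feel for the stack, here is a small self-contained sketch that runs fresh and constrain with runRWST and runExcept. The Type and TVar definitions are hypothetical stand-ins for part one's, and runInfer and example are illustrative helpers, not part of the article's code. Each call to fresh bumps the State counter, while constrain appends to the Writer's list.

```haskell
import qualified Data.Map as Map
import qualified Data.Set as Set
import Control.Monad.RWS (RWST, runRWST, get, put, tell)
import Control.Monad.Except (throwError)
import Control.Monad.Trans.Except (Except, runExcept)

-- Hypothetical stand-ins for part one's types.
newtype TVar = TV String deriving (Eq, Ord, Show)
data Type = TVar TVar | TCon String [Type] deriving (Eq, Show)

data Scheme = Forall (Set.Set TVar) Type deriving Show
data Constraint = Constraint Type Type deriving (Eq, Show)
type Context = Map.Map String Scheme
type Infer a = RWST Context [Constraint] Int (Except String) a

constrain :: Type -> Type -> Infer ()
constrain t1 t2 = tell [Constraint t1 t2]

fresh :: Infer Type
fresh = do
  count <- get
  put (count + 1)
  return . TVar . TV $ show count

-- Hypothetical runner: empty context, counter starting at 0.
-- runRWST returns (result, final state, accumulated writer output).
runInfer :: Infer a -> Either String (a, Int, [Constraint])
runInfer m = runExcept (runRWST m Map.empty 0)

-- Two fresh variables and one constraint relating them.
example :: Infer Type
example = do
  t0 <- fresh
  t1 <- fresh
  constrain t0 (TCon "->" [t1, TCon "Int" []])
  return t1
```

Running runInfer example gives Right (TVar (TV "1"), 2, [Constraint (TVar (TV "0")) (TCon "->" [TVar (TV "1"), TCon "Int" []])]), while an action that calls throwError comes out as a Left.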
Let's also define a Substitutable typeclass, which will provide apply for applying substitutions and tvs for querying free type variables. As stated above, we will represent substitutions simply as a Data.Map from type variables to types. The tvs function will give us a Data.Set of type variables, just like the definition above. A function called compose will be defined to make composing substitutions easier: applying compose s1 s2 to a type should give the same result as applying s2 first and then s1.
type Subst = Map.Map TVar Type
compose :: Subst -> Subst -> Subst
compose s1 s2 = Map.map (apply s1) s2 `Map.union` s1  -- Apply s2 first, then s1
class Substitutable a where
  apply :: Subst -> a -> a    -- Apply a substitution
  tvs :: a -> Set.Set TVar    -- Free type variables
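We can then make a few data types instances of Substitutable. We have already gone over the rules for querying type variables, and the implementations of apply are mostly straightforward, so I won't go over them in detail. The one subtlety is the Scheme instance: apply first removes the scheme's quantified variables from the substitution, so bound variables are never replaced.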
instance Substitutable Type where
  tvs (TVar tv) = Set.singleton tv
  tvs (TCon _ ts) = foldr (Set.union . tvs) Set.empty ts
  apply s t@(TVar tv) = Map.findWithDefault t tv s
  apply s (TCon c ts) = TCon c $ map (apply s) ts
instance Substitutable Scheme where
  tvs (Forall vs t) = tvs t `Set.difference` vs
  apply s (Forall vs t) = Forall vs $ apply (foldr Map.delete s vs) t  -- Don't touch bound vars
instance Substitutable Constraint where
  tvs (Constraint t1 t2) = tvs t1 `Set.union` tvs t2
  apply s (Constraint t1 t2) = Constraint (apply s t1) (apply s t2)
instance Substitutable a => Substitutable [a] where
  tvs l = foldr (Set.union . tvs) Set.empty l
  apply s = map (apply s)
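A quick way to sanity-check compose is the property it is meant to satisfy: applying compose s1 s2 should give the same result as applying s2 and then s1. The sketch below checks that on one example. It uses hypothetical stand-in types and the standard definition (map apply s1 over s2, then take the union with s1), which is an assumption that may differ in form from your version. It also shows that the order of composition matters.

```haskell
import qualified Data.Map as Map

-- Hypothetical stand-ins for part one's types.
newtype TVar = TV String deriving (Eq, Ord, Show)
data Type = TVar TVar | TCon String [Type] deriving (Eq, Show)

type Subst = Map.Map TVar Type

apply :: Subst -> Type -> Type
apply s t@(TVar v) = Map.findWithDefault t v s
apply s (TCon c ts) = TCon c (map (apply s) ts)

-- Standard composition: apply s1 to s2's results, then keep s1's own bindings
-- for variables s2 does not mention (Map.union is left-biased).
compose :: Subst -> Subst -> Subst
compose s1 s2 = Map.map (apply s1) s2 `Map.union` s1

arrow :: Type -> Type -> Type
arrow x y = TCon "->" [x, y]

tInt :: Type
tInt = TCon "Int" []

a, b, c :: Type
a = TVar (TV "a")
b = TVar (TV "b")
c = TVar (TV "c")

s1, s2 :: Subst
s1 = Map.fromList [(TV "b", tInt)]       -- [b := Int]
s2 = Map.fromList [(TV "a", arrow b c)]  -- [a := b -> c]

t :: Type
t = arrow a b                            -- a -> b

-- Composition law: apply (compose x y) behaves like apply x . apply y
lawHolds :: Subst -> Subst -> Type -> Bool
lawHolds x y ty = apply (compose x y) ty == apply x (apply y ty)
```

Both lawHolds s1 s2 t and lawHolds s2 s1 t are True, yet compose s1 s2 turns t into (Int -> c) -> Int while compose s2 s1 turns it into (b -> c) -> Int. That is why solve and unify are careful about the order in which they compose.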
Next up, we must implement generalization and instantiation, both of which were discussed in part one. The generalize function takes a context and a type and produces a type scheme that quantifies over every type variable free in the type but not free in the context. The instantiate function takes a type scheme and returns a type in which every quantified type variable has been replaced with a fresh type variable. The replacement is done by simply applying a substitution.
generalize :: Context -> Type -> Scheme
generalize ctx t = Forall (tvs t `Set.difference` tvs (Map.elems ctx)) t  -- Quantify vars not free in ctx
instantiate :: Scheme -> Infer Type
instantiate (Forall vs t) = do
  let vars = Set.toList vs
  ftvs <- traverse (const fresh) vars       -- One fresh variable per quantified variable
  let subst = Map.fromList (zip vars ftvs)
  return $ apply subst t
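The sketch below exercises generalize and instantiate in isolation. It repeats the definitions above on hypothetical stand-ins for part one's types, plus an illustrative runInfer helper (not part of the article's code). A variable that is free in the context, like b when x : b is in scope, is left unquantified, and instantiation replaces only the quantified variables with fresh ones.

```haskell
import qualified Data.Map as Map
import qualified Data.Set as Set
import Control.Monad.RWS (RWST, runRWST, get, put)
import Control.Monad.Trans.Except (Except, runExcept)

-- Hypothetical stand-ins for part one's types.
newtype TVar = TV String deriving (Eq, Ord, Show)
data Type = TVar TVar | TCon String [Type] deriving (Eq, Show)
data Scheme = Forall (Set.Set TVar) Type deriving (Eq, Show)
data Constraint = Constraint Type Type deriving (Eq, Show)

type Subst = Map.Map TVar Type
type Context = Map.Map String Scheme
type Infer a = RWST Context [Constraint] Int (Except String) a

class Substitutable a where
  apply :: Subst -> a -> a
  tvs :: a -> Set.Set TVar

instance Substitutable Type where
  tvs (TVar tv) = Set.singleton tv
  tvs (TCon _ ts) = foldr (Set.union . tvs) Set.empty ts
  apply s t@(TVar tv) = Map.findWithDefault t tv s
  apply s (TCon c ts) = TCon c (map (apply s) ts)

instance Substitutable Scheme where
  tvs (Forall vs t) = tvs t `Set.difference` vs
  apply s (Forall vs t) = Forall vs (apply (foldr Map.delete s vs) t)

instance Substitutable a => Substitutable [a] where
  tvs = foldr (Set.union . tvs) Set.empty
  apply s = map (apply s)

fresh :: Infer Type
fresh = do
  count <- get
  put (count + 1)
  return . TVar . TV $ show count

generalize :: Context -> Type -> Scheme
generalize ctx t = Forall (tvs t `Set.difference` tvs (Map.elems ctx)) t

instantiate :: Scheme -> Infer Type
instantiate (Forall vs t) = do
  let vars = Set.toList vs
  ftvs <- traverse (const fresh) vars
  return $ apply (Map.fromList (zip vars ftvs)) t

-- Hypothetical runner: empty context, counter starting at 0.
runInfer :: Infer a -> Either String (a, Int, [Constraint])
runInfer m = runExcept (runRWST m Map.empty 0)

arrow :: Type -> Type -> Type
arrow x y = TCon "->" [x, y]

a, b :: Type
a = TVar (TV "a")
b = TVar (TV "b")

-- A context in which x : b is monomorphic, so b must stay free.
ctxWithX :: Context
ctxWithX = Map.fromList [("x", Forall Set.empty b)]
```

With these definitions, generalize ctxWithX (arrow a b) quantifies only a, while generalize Map.empty (arrow a b) quantifies both. Instantiating the latter from counter 0 yields 0 -> 1; instantiating the former yields 0 -> b.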
We can begin the main expression inference function now. It has type Expr -> Infer Type. All it does is work backwards through the inference rules we discussed previously (go back and review part one if needed). Each step is explained in the comments, although most of them should be fairly obvious.
infer :: Expr -> Infer Type
infer = \case
  EInt _ -> return TInt    -- Integer literal
  EBool _ -> return TBool  -- Boolean literal
  EVar v -> do
    ctx <- ask  -- Retrieve context from Reader
    case Map.lookup v ctx of
      Just t -> instantiate t  -- Instantiate type scheme for use
      Nothing -> throwError $ "Undefined variable " ++ v  -- Variable not defined
  EIf c a b -> do
    ct <- infer c       -- Infer type of condition expression
    at <- infer a       -- Infer type of main branch
    bt <- infer b       -- Infer type of else branch
    constrain ct TBool  -- Condition expression should be a Bool
    constrain at bt     -- Branches should be of the same type
    return at           -- Return type of either branch
  EAbs p e -> do
    pt <- fresh                              -- Fresh type variable for param
    let ps = Forall Set.empty pt             -- Params are not generalized
    et <- local (Map.insert p ps) (infer e)  -- Infer function body with param defined
    return $ pt :-> et                       -- Function has type pt -> et
  EApp f a -> do
    ft <- infer f  -- Infer type of expression being called
    at <- infer a  -- Infer type of argument
    rt <- fresh    -- Fresh type variable for result type
    constrain ft (at :-> rt)
    return rt
  EBin o a b -> do
    let ot = case o of  -- Operators are functions
          Add -> TInt :-> (TInt :-> TInt)
          Sub -> TInt :-> (TInt :-> TInt)
          -- NOTE: add more operators here!
    at <- infer a  -- Infer left operand
    bt <- infer b  -- Infer right operand
    t <- fresh     -- Result type
    constrain ot (at :-> (bt :-> t))
    return t
  ELet v e b -> do
    et <- infer e                            -- Infer variable type
    ctx <- ask
    let es = generalize ctx et               -- Generalize variable type
    bt <- local (Map.insert v es) (infer b)  -- Infer body with variable defined
    return bt
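As a hand-traced illustration (assuming Name is a String and the fresh counter starts at 0, so fresh variables are written $t_0, t_1, \ldots$), here is what infer does for EApp (EAbs "x" (EVar "x")) (EInt 1), i.e. $(\lambda x.\, x)\ 1$:

$$
\begin{aligned}
\text{EAbs:}\quad & x : t_0 \;\Rightarrow\; \lambda x.\, x : t_0 \to t_0 \\
\text{EInt:}\quad & 1 : \text{Int} \\
\text{EApp:}\quad & \text{result type } t_1,\ \text{constraint } t_0 \to t_0 \sim \text{Int} \to t_1
\end{aligned}
$$

No unification happens yet; the single constraint is handed to the solver afterwards. Solving it with the unification rules gives $[t_0 := \text{Int},\ t_1 := \text{Int}]$, so the expression's type $t_1$ resolves to $\text{Int}$.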
That's it for the main type inference function. Read over it a few times to make sure you understand it. All that's left for us to implement now is the unification process, which is more or less a direct translation of the rules from above into code. For the entire solving process we will use a different monad: plain Except. A helper function named bind will be defined for code reuse; it performs the occurs check.
type Solve a = Except String a
unify :: Type -> Type -> Solve Subst
unify a b | a == b = return Map.empty -- Same
unify (TVar v) t = bind v t -- Var Left
unify t (TVar v) = bind v t -- Var Right
unify a@(TCon n1 ts1) b@(TCon n2 ts2)                       -- Arrow (->) / other TCons
  | n1 /= n2 || length ts1 /= length ts2 =                   -- Constructor names and arities must match
      throwError $ "Type mismatch " ++ show a ++ " and " ++ show b
  | otherwise = unifyMany ts1 ts2
  where
    unifyMany [] [] = return Map.empty
    unifyMany (t1 : rest1) (t2 : rest2) = do
      s1 <- unify t1 t2                                      -- Unify the first pair
      s2 <- unifyMany (apply s1 rest1) (apply s1 rest2)      -- Apply s1 to the rest, then continue
      return (s2 `compose` s1)
    unifyMany _ _ = throwError "Arity mismatch"              -- Unreachable: lengths checked above
bind :: TVar -> Type -> Solve Subst
bind v t
  | v `Set.member` tvs t = throwError $ "Infinite type " ++ show v ++ " ~ " ++ show t -- Occurs check
  | otherwise = return $ Map.singleton v t
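Finally, we define a function, solve, which takes a list of constraints and builds up a final substitution through unify, applying each new substitution to the constraints that remain. A helper function, runSolve, which runs solve starting from the empty substitution, will also come in handy.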
solve :: Subst -> [Constraint] -> Solve Subst
solve s [] = return s
solve s ((Constraint t1 t2) : cs) = do
  s1 <- unify t1 t2                     -- Solve the first constraint
  solve (s1 `compose` s) (apply s1 cs)  -- Apply it to the remaining constraints and continue
runSolve :: [Constraint] -> Either String Subst
runSolve = runExcept . solve Map.empty
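To see the solver end to end, here is a self-contained sketch that bundles the unification and solving code from this section with hypothetical stand-ins for part one's types (TVar, Type) and a few illustrative helpers (tv, arrow, tInt, tBool). It assumes the standard compose and the name/arity check in unify. The first example is the constraint produced for (\x -> x) 1; the second shows a substitution from one constraint flowing into the next.

```haskell
import qualified Data.Map as Map
import qualified Data.Set as Set
import Control.Monad.Except (throwError)
import Control.Monad.Trans.Except (Except, runExcept)

-- Hypothetical stand-ins for part one's types.
newtype TVar = TV String deriving (Eq, Ord, Show)
data Type = TVar TVar | TCon String [Type] deriving (Eq, Show)
data Constraint = Constraint Type Type deriving (Eq, Show)

type Subst = Map.Map TVar Type
type Solve a = Except String a

class Substitutable a where
  apply :: Subst -> a -> a
  tvs :: a -> Set.Set TVar

instance Substitutable Type where
  tvs (TVar tv) = Set.singleton tv
  tvs (TCon _ ts) = foldr (Set.union . tvs) Set.empty ts
  apply s t@(TVar tv) = Map.findWithDefault t tv s
  apply s (TCon c ts) = TCon c (map (apply s) ts)

instance Substitutable Constraint where
  tvs (Constraint t1 t2) = tvs t1 `Set.union` tvs t2
  apply s (Constraint t1 t2) = Constraint (apply s t1) (apply s t2)

instance Substitutable a => Substitutable [a] where
  tvs = foldr (Set.union . tvs) Set.empty
  apply s = map (apply s)

compose :: Subst -> Subst -> Subst
compose s1 s2 = Map.map (apply s1) s2 `Map.union` s1

unify :: Type -> Type -> Solve Subst
unify a b | a == b = return Map.empty
unify (TVar v) t = bind v t
unify t (TVar v) = bind v t
unify a@(TCon n1 ts1) b@(TCon n2 ts2)
  | n1 /= n2 || length ts1 /= length ts2 =
      throwError $ "Type mismatch " ++ show a ++ " and " ++ show b
  | otherwise = unifyMany ts1 ts2

-- Unify parameter lists pairwise, threading the substitution through.
unifyMany :: [Type] -> [Type] -> Solve Subst
unifyMany [] [] = return Map.empty
unifyMany (t1 : ts1) (t2 : ts2) = do
  s1 <- unify t1 t2
  s2 <- unifyMany (apply s1 ts1) (apply s1 ts2)
  return (s2 `compose` s1)
unifyMany _ _ = throwError "Arity mismatch"

bind :: TVar -> Type -> Solve Subst
bind v t
  | v `Set.member` tvs t = throwError $ "Infinite type " ++ show v ++ " ~ " ++ show t
  | otherwise = return (Map.singleton v t)

solve :: Subst -> [Constraint] -> Solve Subst
solve s [] = return s
solve s (Constraint t1 t2 : cs) = do
  s1 <- unify t1 t2
  solve (s1 `compose` s) (apply s1 cs)

runSolve :: [Constraint] -> Either String Subst
runSolve = runExcept . solve Map.empty

-- Illustrative helpers for building examples.
tv :: String -> Type
tv = TVar . TV

arrow :: Type -> Type -> Type
arrow x y = TCon "->" [x, y]

tInt, tBool :: Type
tInt = TCon "Int" []
tBool = TCon "Bool" []
```

For example, runSolve [Constraint (arrow (tv "0") (tv "0")) (arrow tInt (tv "1"))] returns Right (Map.fromList [(TV "0", tInt), (TV "1", tInt)]), and runSolve [Constraint (tv "a") (tv "b"), Constraint (tv "b") tInt] maps both a and b to Int. Constraints such as a ~ a -> b or Int ~ Bool produce a Left with the corresponding error message.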
With all of that, we're finally finished! A Hindley-Milner type system fully implemented in Haskell.