# The EM Algorithm
The EM (expectation-maximization) algorithm is widely used for learning directed latent-variable graphical models of the form \(p(x, z \mid \theta)\), where

- \(z\) are the latent variables,
- \(\theta\) are the model parameters.
Since the latent variables are unobserved, they pose a challenge for optimization.
The idea behind EM is that:

- If the latent variables were fully observed, we could optimize the log-likelihood exactly, using the closed-form solutions seen previously for \(p(x, z)\).
- Knowing the parameters, we can often efficiently compute the posterior \(p(z \mid x, \theta)\) (for some models this is not true).
The algorithm works in two steps:

- Given an estimate \(\theta_t\) of the parameters, compute the posterior \(p(z \mid x, \theta_t)\) and use it to "hallucinate" values for \(z\).
- Find a new estimate \(\theta_{t+1}\) by optimizing the resulting tractable objective.
Hallucinating the data means computing the expected log-likelihood

$$
\mathbb{E}_{z \sim p(z \mid x, \theta_t)}\left[\log p(x, z \mid \theta)\right].
$$

If \(z\) is not too high-dimensional, we can compute this expectation. And since the expectation sits outside the log, the objective decomposes into a sum of log conditional probability distributions that can be optimized independently.
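As a concrete illustration (not spelled out above): for a directed model that factors as \(p(x, z \mid \theta) = p(z \mid \theta)\, p(x \mid z, \theta)\), linearity of expectation splits the objective term by term,

$$
\mathbb{E}_{z \sim p(z \mid x, \theta_t)}\left[\log p(x, z \mid \theta)\right]
= \mathbb{E}_{z \sim p(z \mid x, \theta_t)}\left[\log p(z \mid \theta)\right]
+ \mathbb{E}_{z \sim p(z \mid x, \theta_t)}\left[\log p(x \mid z, \theta)\right],
$$

so each conditional distribution can be fit separately, just as in the fully observed case but with weighted ("hallucinated") counts.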
## Formal Algorithm
For a dataset \(D\), we initialize \(\theta_0\) to some random value and, for \(t = 0, 1, 2, \ldots\), repeat the following two steps until convergence.

E-step: For each \(x \in D\), compute the posterior \(p(z \mid x, \theta_t)\).

M-step: Compute the new parameters via

$$
\theta_{t+1} = \arg\max_{\theta} \sum_{x \in D} \mathbb{E}_{z \sim p(z \mid x, \theta_t)} \log p(x, z \mid \theta).
$$
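To make the two steps concrete, here is a minimal sketch of EM for a two-component one-dimensional Gaussian mixture, a model chosen only because both steps have simple closed forms; the function and variable names (`em_gmm_1d`, `resp`, etc.) are illustrative and not part of the original text.

```python
import numpy as np

def em_gmm_1d(x, n_iter=50, seed=0):
    """Sketch of EM for a 2-component 1-D Gaussian mixture."""
    rng = np.random.default_rng(seed)
    # Initialize theta_0 = (mixing weights, means, variances) randomly.
    pi = np.array([0.5, 0.5])
    mu = rng.choice(x, size=2, replace=False)
    var = np.array([x.var(), x.var()])

    for _ in range(n_iter):
        # E-step: posterior p(z | x, theta_t) for every point (the "responsibilities").
        log_joint = (
            -0.5 * (x[:, None] - mu) ** 2 / var
            - 0.5 * np.log(2 * np.pi * var)
            + np.log(pi)
        )                                            # shape (N, 2)
        log_evidence = np.logaddexp(log_joint[:, 0], log_joint[:, 1])
        resp = np.exp(log_joint - log_evidence[:, None])

        # M-step: closed-form argmax of the expected complete-data log-likelihood.
        Nk = resp.sum(axis=0)
        pi = Nk / len(x)
        mu = (resp * x[:, None]).sum(axis=0) / Nk
        var = (resp * (x[:, None] - mu) ** 2).sum(axis=0) / Nk + 1e-6  # small floor for stability
    return pi, mu, var
```

On data drawn from two well-separated Gaussians, e.g. `x = np.concatenate([np.random.normal(-2, 1, 200), np.random.normal(3, 1, 200)])`, the recovered means should land near \(-2\) and \(3\), up to a permutation of the components.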
## Detailed overview
Let \(x_i\) be the visible or observed variables in case \(i\), and let \(z_i\) be the hidden or missing variables. The goal is to maximize the log-likelihood of the observed data:

$$
\ell(\theta) = \sum_{i} \log p(x_i \mid \theta) = \sum_{i} \log \sum_{z_i} p(x_i, z_i \mid \theta).
$$
This is hard to optimize since we cannot push the log inside the sum.
EM gets around this problem as follows. Define the complete-data log-likelihood to be

$$
\ell_c(\theta) = \sum_{i} \log p(x_i, z_i \mid \theta).
$$
This cannot be computed, since \(z_i\) is unknown. So let us define the expected complete-data log-likelihood as follows:

$$
Q(\theta, \theta^{t-1}) = \mathbb{E}\left[\ell_c(\theta) \mid D, \theta^{t-1}\right],
$$
where \(t\) is the current iteration number and \(Q\) is called the auxiliary function. The expectation is taken with respect to the old parameters \(\theta^{t-1}\) and the observed data \(D\). The goal of the E step is to compute \(Q(\theta, \theta^{t-1})\), or rather the terms inside it on which the MLE depends; these are known as the expected sufficient statistics (ESS). In the M step, we optimize the \(Q\) function with respect to \(\theta\):

$$
\theta^{t} = \arg\max_{\theta}\, Q(\theta, \theta^{t-1}).
$$
To perform MAP estimation, we modify the M step as follows:

$$
\theta^{t} = \arg\max_{\theta}\, Q(\theta, \theta^{t-1}) + \log p(\theta).
$$
The E step remains unchanged. The EM algorithm monotonically increases the log-likelihood of the observed data.
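As a concrete, illustrative instance of this modification: in the Gaussian-mixture sketch above, placing a symmetric Dirichlet(\(\alpha\)) prior on the two mixing weights changes only their M-step update, which in closed form becomes

$$
\pi_k = \frac{N_k + \alpha - 1}{N + K(\alpha - 1)}, \qquad K = 2,
$$

where \(N_k\) is the expected count of points assigned to component \(k\); setting \(\alpha = 1\) (a uniform prior) recovers the usual MLE update \(\pi_k = N_k / N\).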
## Theory
EM monotonically increases the observed-data log-likelihood until it reaches a local maximum.
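This monotonicity gives a useful debugging check when implementing EM: the observed-data log-likelihood, evaluated once per iteration, should never decrease (up to numerical noise). A sketch for the hypothetical mixture model above:

```python
import numpy as np

def observed_loglik(x, pi, mu, var):
    """Observed-data log-likelihood: sum_i log sum_z p(x_i, z | theta)."""
    log_joint = (
        -0.5 * (x[:, None] - mu) ** 2 / var
        - 0.5 * np.log(2 * np.pi * var)
        + np.log(pi)
    )
    return np.logaddexp(log_joint[:, 0], log_joint[:, 1]).sum()

# Inside the EM loop: record the value each iteration and check it never drops.
# history.append(observed_loglik(x, pi, mu, var))
# assert all(b >= a - 1e-8 for a, b in zip(history, history[1:]))
```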
## Connection to Variational Inference
We can understand the behavior of EM by casting it in the framework of variational inference.
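The standard way to make this precise (stated here for reference, not derived in this document) is the decomposition of the marginal log-likelihood into an evidence lower bound (ELBO) plus a KL term, valid for any distribution \(q(z)\):

$$
\log p(x \mid \theta) = \mathbb{E}_{q(z)}\left[\log \frac{p(x, z \mid \theta)}{q(z)}\right] + \mathrm{KL}\left(q(z) \,\|\, p(z \mid x, \theta)\right).
$$

The E step sets \(q(z) = p(z \mid x, \theta_t)\), which makes the bound tight at \(\theta_t\); the M step then maximizes the ELBO over \(\theta\). This viewpoint is what yields the properties below.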
## Properties
- The marginal likelihood increases after each EM cycle.
- Since the marginal likelihood is upper-bounded by its global maximum and increases at every step, EM must eventually converge.

Unfortunately, we are optimizing a non-convex objective, so EM may only find a local optimum rather than the global one. In practice, we run the algorithm from multiple random restarts and keep the best solution, as in the sketch below.
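A minimal sketch of the restart strategy, reusing the hypothetical `em_gmm_1d` and `observed_loglik` helpers from the earlier examples (assuming `x` holds the observed data):

```python
# Run EM from several random initializations and keep the best local optimum,
# scored by the observed-data log-likelihood.
best_params = max(
    (em_gmm_1d(x, seed=s) for s in range(10)),
    key=lambda params: observed_loglik(x, *params),
)
```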