-
Notifications
You must be signed in to change notification settings - Fork 13
/
em_algorithm.tex
319 lines (292 loc) · 13 KB
/
em_algorithm.tex
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
\documentclass[11pt]{article}
% ---------------- Packages ---------------------------
\usepackage[usenames,dvipsnames]{color} % for adding color
\usepackage{amsmath, amsthm, amssymb} % for maths
\usepackage[in,myheadings]{fullpage} % for setting margins
\usepackage{enumerate} % for setting the enumerate bullet style
\usepackage{fancyhdr} % for setting page headers and footers
\usepackage{listings} % for typesetting code
\usepackage{graphicx} % for 'enhanced' graphics package
\usepackage{framed} % for frame borders
\usepackage{float} % for absolute figure placement (H)
\usepackage{url} % for typesetting hyperlinks
\usepackage{bm} % for bold greek letters
% ---------------- Environments -----------------------
\newenvironment{pitemize}{
\vspace{-3mm}
\begin{itemize}
\setlength{\itemsep}{1pt}
\setlength{\parskip}{0pt}
\setlength{\parsep}{0pt}
}{\end{itemize}}
\newenvironment{penumerate}{
\begin{enumerate}
\vspace{-3mm}
\setlength{\itemsep}{1pt}
\setlength{\parskip}{-2pt}
\setlength{\parsep}{-5pt}
}{\end{enumerate}}
\newenvironment{definition}[1][Definition]{
\begin{trivlist}
\item[\hskip \labelsep {\bfseries #1}]
}{\end{trivlist}}
\newenvironment{example}[1][Example]{
\begin{trivlist}
\item[\hskip \labelsep {\bfseries #1}]
}{\end{trivlist}}
\newenvironment{summary}[1][Summary]{
\begin{trivlist}
\item[\hskip \labelsep {\bfseries #1}]
}{\end{trivlist}}
% ---------------- General Macros --------------------
\newcommand{\ul}[1]{\underline{#1}} % underline
\newcommand{\todo}[1]{\textcolor{red}{\textbf{#1}}} % todo
\newcommand{\qpart}[1]{\subsection{#1}\vspace{-3mm}}
\newcommand{\qsection}[1]{\section{#1}\vspace{-3mm}}
% ---------------- Probability Macros -----------------
\renewcommand{\L}[1]{\mathcal{L}} % likelihood
\newcommand{\E}[1]{\mathbb{E}} % expectation
\newcommand{\Cov}[1]{\text{Cov}} % covariance
\newcommand{\Var}[1]{\text{Var}} % variance
\newcommand{\where}[1]{\text{ where }}
\newcommand{\subto}[1]{\text{ subject to }}
% ---------------- Linear Algebra Macros --------------
\newcommand{\bs}{\boldsymbol} % bold greek letters
\newcommand{\norm}[1]{\left\lVert#1\right\rVert} % norm
% ---------------- Number Theory Macros ---------------
\newcommand{\R}{\mathbb{R}}
\newcommand{\N}{\mathcal{N}}
% ---------------- Document Settings ------------------
\setlength{\headheight}{20pt}
\setlength{\textheight}{8in}
\setlength{\footskip}{0.8in}
\setlength{\parskip}{6pt plus 2pt minus 1pt}
\setlength{\parindent}{0pt}
\pagestyle{fancy}
% Headers
\lhead{\\EM Algorithm Primer }
\rhead{\\
David Bourgin}
\begin{document}
The \textbf{EM algorithm} is a method for fitting latent variable models. Let:
\begin{pitemize}
\item $X = \{ x_1, \ldots, x_N \}$ be the collection of observed data points
\item $\theta$ be a collection of model parameters
\item $T = \{t_1, \ldots, t_N \}$ be the collection of latent variables
associated with each data point
\end{pitemize}
\section*{Variational Lower Bound on Marginal Likelihood}
In fitting our model we will attempt to find the setting of the parameters
$\theta$ that maximizes the \textbf{marginal likelihood} of the data:
\begin{align*}
P(X|\theta) &= \prod_{i=1}^N P(x_i | \theta) &&\text{Assume data are iid}\\
% &= \prod_{i=1}^N \sum_{c=1}^T P(x_i | t_i = c, \theta) P(t_i = c |
% \theta) \\
&= \prod_{i=1}^N \sum_{c=1}^T P(x_i, t_i = c | \theta)
\end{align*}
As always, it is generally easier to maximize the \textbf{marginal log
likelihood} rather than the likelihood directly:
\begin{align*}
\log P(X|\theta) &= \log \prod_{i=1}^N P(x_i | \theta) &&\text{Assume data
are iid.}\\
&= \sum_{i=1}^N \log P(x_i | \theta)\\
&= \sum_{i=1}^N \log \left[ \sum_{c=1}^T P(x_i, t_i = c | \theta )\right]
\end{align*}
The problem is that this expression is still difficult to optimize directly
(e.g., via SGD). In EM, we instead try to maximize a \textbf{lower
bound}, $\mathcal{L}$, on the marginal log likelihood:
\begin{equation*}
\overbrace{\log P(X|\theta)}^{\text{Marginal log likelihood}} \geq \underbrace{\mathcal{L}}_{\text{Variational lower
bound}}
\end{equation*}
The issue is that there is no reason to expect a single lower bound to be useful
for finding a local maximum of the marginal log likelihood. What we really want
is a \emph{family} of lower bounds, which we can tune to get better and better
local approximations to the marginal log likelihood at $\theta$. To achieve this, we introduce a
new parameter to the lower bound, a distribution over the latent variable classes:
\begin{equation*}
q(t_i = c)
\end{equation*}
This distribution will be used as a flexible parameter of our family of lower
bounds, $\mathcal{L}$, allowing us to modify the form of the lower bound over the
course of optimization.\\
\\
We derive the form for the family of lower bounds using \textbf{Jensen's
inequality}:
\begin{equation}
\log (\mathbb{E}[X]) \geq \mathbb{E}[\log X]
\end{equation}
or, if we assume $X$ is a discrete random variable:
\begin{equation}
\log \sum_i \alpha_i x_i \geq \sum_i \alpha_i \log
x_i
\end{equation}
where $\alpha_i \geq 0 \ \forall i$ and $\sum_i \alpha_i = 1$. Using this inequality, we can derive a lower bound
on the marginal log likelihood:
\begin{align*}
\log P(X|\theta) &= \sum_{i=1}^N \log \left[ \sum_{c=1}^T
P(x_i, t_i = c | \theta) \right]\\
&= \sum_{i=1}^N \log \left[ \sum_{c=1}^T \underbrace{\frac{q(t_i = c)}{q(t_i
= c)}}_{\text{this is just 1}} \times
P(x_i, t_i = c | \theta) \right]
\end{align*}
At this point, notice that we can rewrite the last line as
\begin{equation*}
\log P(X|\theta) = \sum_{i=1}^N \log \mathbb{E}_q \left[\frac{P(x_i, T |
\theta)}{q(T)} \right]
\end{equation*}
This allows us to apply Jensen's inequality (Eq.~1) to define a family of lower bounds:
\begin{align*}
\log P(X|\theta) &\geq \mathcal{L}(\theta, q)\\
\sum_{i=1}^N \log \mathbb{E}_q \left[\frac{P(x_i, T | \theta)}{q(T)} \right]
&\geq \sum_{i=1}^N \mathbb{E}_q \left[ \log \frac{P(x_i, T | \theta)}{q(T)}
\right]\\
\sum_{i=1}^N \log \left[ \sum_{c=1}^T \frac{q(t_i = c)}{q(t_i
= c)} \times P(x_i, t_i = c | \theta) \right] &\geq \underbrace{\sum_{i=1}^N \sum_{c=1}^T
q(t_i = c) \log \frac{P(x_i, t_i = c | \theta)}{q(t_i =
c)}}_{\mathcal{L}(\theta, q)}
\end{align*}
\begin{framed}
\vspace{-5mm}
\begin{summary} \textit{Variational Lower Bound}\\
We have now derived a \emph{family} of lower bounds on the marginal log
likelihood, $\log P(X | \theta)$. The functions in this family are of
the form
\begin{align*}
\mathcal{L}(\theta, q) &= \sum_{i=1}^N \sum_{c=1}^T q(t_i = c) \log
\left[ \frac{P(x_i, t_i = c | \theta)}{q(t_i = c)} \right]\\
&= \sum_{i=1}^N \mathbb{E}_{q} \left[\log \frac{P(x_i, T | \theta)}{q(T)}\right]
\end{align*}
During optimization, we can modify the distribution $q$ to achieve
lower bounds that provide better local approximations to $\log
P(X|\theta)$ at $\theta$.
\end{summary}
\vspace{-2mm}
\end{framed}
\section*{EM Algorithm Overview}
The EM algorithm is an iterative approach to maximizing the marginal
likelihood, $P(X|\theta)$, via coordinate ascent on the variational lower
bound, $\mathcal{L}(\theta, q)$. It consists of two steps, which are repeated until
convergence:
\begin{enumerate}
\item \textbf{E-Step}: Given a starting value for $\theta$, find the
distribution $q^*$ that maximizes $\mathcal{L}(\theta, q)$.
\item \textbf{M-Step}: Given the distribution $q^*$ identified during the
E-step, find the value $\theta^*$ that maximizes $\mathcal{L}(\theta,
q^*)$.
\end{enumerate}
\begin{figure}[htpb]
\centering
\includegraphics[width=\linewidth]{figs/em.png}
\end{figure}
\subsection*{E-Step Details}
During the E-step, we fix the current value for the parameters, $\theta$, and
try to maximize the variational lower bound, $\mathcal{L}$, with respect to the
distribution $q$:
\begin{equation*}
q^{(k+1)} = \arg \max_q \mathcal{L}(\theta^{(k)}, q)
\end{equation*}
Because $\log P(X|\theta^{(k)})$ does not depend on $q$, this maximization
problem is equivalent to \emph{minimizing} the gap between $\log
P(X|\theta^{(k)})$ and $\mathcal{L}(\theta^{(k)}, q)$:
\begin{equation*}
q^{(k+1)} = \arg \min_q \log P(X|\theta^{(k)}) - \mathcal{L}(\theta^{(k)},
q)
\end{equation*}
We can rewrite the gap between $\log P(X|\theta^{(k)})$ and
$\mathcal{L}(\theta^{(k)}, q)$ as:
\begin{align*}
&\log P(X|\theta) - \mathcal{L}(\theta, q)\\
&= \sum_{i=1}^N \log P(x_i|\theta) - \sum_{i=1}^N
\sum_{c=1}^T q(t_i = c) \log \left[ \frac{P(x_i, t_i = c |
\theta)}{q(t_i=c)} \right]\\
&= \sum_{i=1}^N \left( \log P(x_i | \theta) \underbrace{\sum_{c=1}^T q(t_i =
c)}_{\text{this is just 1}} - \sum_{c=1}^T q(t_i = c) \log \left[ \frac{P(x_i,
t_i = c | \theta)}{q(t_i = c)} \right] \right)\\
&= \sum_{i=1}^N \sum_{c=1}^T q(t_i = c) \left( \log P(x_i | \theta) - \log
\frac{P(x_i, t_i = c | \theta)}{q(t_i = c)} \right)\\
&= \sum_{i=1}^N \sum_{c=1}^T q(t_i = c) \log \left[ \frac{P(x_i | \theta)
q(t_i = c)}{P(x_i, t_i = c | \theta) } \right]\\
&= \sum_{i=1}^N \sum_{c=1}^T q(t_i = c) \log \left[ \frac{P(x_i | \theta)
q(t_i = c)}{P(t_i = c | x_i, \theta) P(x_i | \theta)} \right]\\
&= \sum_{i=1}^N \sum_{c=1}^T q(t_i=c) \log \left( \frac{q(t_i =
c)}{P(t_i=c|x_i, \theta)} \right) \\
&= \sum_{i=1}^N \mathbb{KL}(q(t_i) \ || \ P(t_i| x_i, \theta))
\end{align*}
Thus we have that during the E-step,
\begin{align*}
q^{(k+1)} &= \arg \min_q \log P(X|\theta^{(k)}) - \mathcal{L}(\theta^{(k)},
q)\\
&= \arg \min_q \sum_{i=1}^N \mathbb{KL}(q(t_i) \ || \ P(t_i| x_i, \theta^{(k)}))
\end{align*}
Because the KL divergence is non-negative and equals zero only when $q(t_i) = P(t_i| x_i,
\theta^{(k)})$, we conclude that the update for the \textbf{E-step} is simply:
\begin{equation}
q^{(k+1)} = \underbrace{P(T | X, \theta^{(k)})}_{\text{Posterior over
latent classes}}
\end{equation}
\begin{framed}
\vspace{-5mm}
\begin{summary} \textit{E-Step}\\
During the \textbf{E-step}, we fix the current value for the parameters, $\theta$, and
try to maximize the variational lower bound, $\mathcal{L}$, with respect to the
distribution $q$.\\
\\
Above, we demonstrate that maximizing $\mathcal{L}$ wrt $q$ is
equivalent to minimizing the gap between $\log P(X|\theta)$ and
$\mathcal{L}$, which is in turn equivalent to minimizing the sum of the
KL divergences between $q(t_i)$ and $P(t_i | x_i, \theta)$.\\
\\
This observation implies that the update during the \textbf{E-step}
should simply be:
\begin{equation*}
q^{(k+1)} = P(T | X, \theta^{(k)})
\end{equation*}
The caveat is that it is often intractable to compute the posterior
over latent classes, $P(T | X, \theta)$, exactly. In these cases, it is necessary
to restrict $q$ to a tractable family of distributions (e.g.,
a mean field family) and to choose the member of that family that minimizes the
KL divergence between $q$ and $P(T | X, \theta)$. This approach is known as \textbf{variational EM}.
\end{summary}
\vspace{-2mm}
\end{framed}
\subsection*{M-Step Details}
During the M-step we fix $q$ to the value we computed during the E-step and try
to find $\theta^{(k+1)}$ that maximizes $\mathcal{L}$:
\begin{equation*}
\theta^{(k+1)} = \arg \max_\theta \mathcal{L}(\theta, q^{(k+1)})
\end{equation*}
Here, we decompose the variational lower bound into terms that do and do not
depend on $\theta$:
\begin{align*}
\mathcal{L}(\theta, q) &= \sum_{i=1}^N \sum_{c=1}^T q(t_i = c) \log \left[
\frac{P(x_i, t_i = c | \theta)}{q(t_i = c)} \right] \\
&= \sum_{i=1}^N \sum_{c=1}^T q(t_i = c) \log P(x_i, t_i = c | \theta) -
\underbrace{\sum_{i=1}^N \sum_{c=1}^T q(t_i = c) \log q(t_i = c)}_{\text{does not depend
on $\theta$}}\\
&= \sum_{i=1}^N \sum_{c=1}^T q(t_i = c) \log P(x_i, t_i = c |
\theta) + \text{const}\\
&= \mathbb{E}_q[ \log P(X, T | \theta) ] + \text{const}
\end{align*}
This expectation, $\mathbb{E}_q[ \log P(X, T | \theta) ]$, is usually concave and
tends to be relatively easy to maximize with respect to $\theta$ for most
models.
\begin{framed}
\vspace{-5mm}
\begin{summary} \textit{M-Step}\\
During the \textbf{M-step} we fix $q$ to the value we computed during
the previous E-step and try to find $\theta^{(k+1)}$ that maximizes
$\mathcal{L}$:
\begin{equation*}
\theta^{(k+1)} = \arg \max_\theta \mathcal{L}(\theta, q^{(k+1)})
\end{equation*}
In the derivation above, we showed that this is equivalent to finding
$\theta$ that maximizes the following expected value:
\begin{equation}
\theta^{(k+1)} = \arg \max_\theta \mathbb{E}_q[ \log P(X, T | \theta) ]
\end{equation}
This is the M-step update, typically achieved by taking the partial
derivative of the above expectation wrt each parameter, setting it to 0,
and solving.
\end{summary}
\vspace{-2mm}
\end{framed}
\end{document}