About the simplest thing you can do with MCMC is unmix a mixture of Gaussians. I thought this would be very easy, but it turns out there are way more ways to do it wrong than to do it right...

Let's start with some basic example code that finds the means of two mixed Gaussians: http://stackoverflow.com/questions/21005541/converting-a-mixture-of-gaussians-to-pymc3. This code has at least been looked at and bugfixed by PyMC3 nabob Chris Fonnesbeck, so I expect it's a pretty reasonable way to do this.

In [77]:
from __future__ import print_function
import pymc as pm
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [98]:
# Example starting code in this cell
n1 = 500
n2 = 200
n = n1+n2

mean1 = 21.8
mean2 = 42.0

# precision = 1/sigma^2
precision = 0.1
sigma = np.sqrt(1 / precision)

print("sigma: %s" % sigma)

data1 = np.random.normal(mean1,sigma,n1)
data2 = np.random.normal(mean2,sigma,n2)

data = np.concatenate([data1 , data2])

with pm.Model() as model:
    p = pm.Uniform( "p", 0 , 1) #this is the fraction that come from mean1 vs mean2

    ber = pm.Bernoulli( "ber", p = p, shape=len(data)) # produces 1 with proportion p.

    sigma = pm.Uniform('sigma', 0, 100)
    precision = sigma**-2

    mean = pm.Normal( "mean", 0, 0.01, shape=2 )

    mu = pm.Deterministic('mu', mean[ber])

    process = pm.Normal('process', mu=mu, tau=precision, observed=data)

with model:
    step1 = pm.Metropolis([p, sigma, mean])
    step2 = pm.BinaryMetropolis([ber])
    trace = pm.sample(10000, [step1, step2])
    pm.traceplot(trace[1000:][::10])
sigma: 3.16227766017
 [-----------------100%-----------------] 10000 of 10000 complete in 11.8 sec

OK, bad start. The model completely fails to find the two Gaussians. It thinks both Gaussians are on top of each other, half-way between the two true clusters and with a large sigma. Increasing the number of samples above 10000 does not fix it. It's pretty happy in this rut and it won't budge.

A simpler problem

Let's back up a bit and make the problem a bit more like a textbook example. First we'll make the clusters more obvious: I'll change the means to 20 and 80 instead of 21.8 and 42, and reduce the precision from 0.1 to 0.05 (sigma goes from about 3.2 to about 4.5). Does that help?

In [94]:
n1 = 500
n2 = 200
n = n1+n2

mean1 = 20
mean2 = 80

# precision = 1/sigma^2
precision = .05
sigma = np.sqrt(1 / precision)
print("sigma: %s" % sigma)

data1 = np.random.normal(mean1,sigma,n1)
data2 = np.random.normal(mean2,sigma,n2)
data = np.concatenate([data1 , data2])

with pm.Model() as model:
    p = pm.Uniform( "p", 0 , 1) #this is the fraction that come from mean1 vs mean2
    ber = pm.Bernoulli( "ber", p = p, shape=len(data)) # produces 1 with proportion p.

    sigma = pm.Uniform('sigma', 0, 100)
    precision = sigma**-2

    mean = pm.Normal( "mean", 0, 0.01, shape=2 )
    mu = pm.Deterministic('mu', mean[ber])
    process = pm.Normal('process', mu=mu, tau=precision, observed=data)

with model:
    step1 = pm.Metropolis([p, sigma, mean])
    step2 = pm.BinaryMetropolis([ber])
    trace = pm.sample(10000, [step1, step2])
    pm.traceplot(trace[1000:][::10])
sigma: 4.472135955
 [-----------------100%-----------------] 10000 of 10000 complete in 11.5 sec
---------------------------------------------------------------------------
LinAlgError                               Traceback (most recent call last)
<ipython-input-94-4d4db1104678> in <module>()
     30     step2 = pm.BinaryMetropolis([ber])
     31     trace = pm.sample(10000, [step1, step2])
---> 32     pm.traceplot(trace[1000:][::10])

/Users/briann/anaconda/lib/python2.7/site-packages/pymc/plots.pyc in traceplot(trace, vars, figsize, lines, combined, grid)
     52                 histplot_op(ax[i, 0], d)
     53             else:
---> 54                 kdeplot_op(ax[i, 0], d)
     55             ax[i, 0].set_title(str(v))
     56             ax[i, 0].grid(grid)

/Users/briann/anaconda/lib/python2.7/site-packages/pymc/plots.pyc in kdeplot_op(ax, data)
     83     for i in range(data.shape[1]):
     84         d = data[:, i]
---> 85         density = kde.gaussian_kde(d)
     86         l = np.min(d)
     87         u = np.max(d)

/Users/briann/anaconda/lib/python2.7/site-packages/scipy/stats/kde.pyc in __init__(self, dataset, bw_method)
    186 
    187         self.d, self.n = self.dataset.shape
--> 188         self.set_bandwidth(bw_method=bw_method)
    189 
    190     def evaluate(self, points):

/Users/briann/anaconda/lib/python2.7/site-packages/scipy/stats/kde.pyc in set_bandwidth(self, bw_method)
    496             raise ValueError(msg)
    497 
--> 498         self._compute_covariance()
    499 
    500     def _compute_covariance(self):

/Users/briann/anaconda/lib/python2.7/site-packages/scipy/stats/kde.pyc in _compute_covariance(self)
    507             self._data_covariance = atleast_2d(np.cov(self.dataset, rowvar=1,
    508                                                bias=False))
--> 509             self._data_inv_cov = linalg.inv(self._data_covariance)
    510 
    511         self.covariance = self._data_covariance * self.factor**2

/Users/briann/anaconda/lib/python2.7/site-packages/scipy/linalg/basic.pyc in inv(a, overwrite_a, check_finite)
    381         inv_a, info = getri(lu, piv, lwork=lwork, overwrite_lu=1)
    382     if info > 0:
--> 383         raise LinAlgError("singular matrix")
    384     if info < 0:
    385         raise ValueError('illegal value in %d-th argument of internal '

LinAlgError: singular matrix

This time it just crashes with a LinAlgError (singular matrix). The only thing I changed here was moving the clusters and reducing their precision a bit! It turns out the crash comes from PyMC3's traceplot choking on the Bernoulli variable: the traceback shows the KDE failing to invert the covariance of the "ber" trace, which is singular whenever an indicator never changes value over the plotted samples. The solution is to skip "ber" in the traceplot using the vars keyword.
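
The underlying failure is easy to reproduce outside PyMC. My assumption here is simply that at least one indicator column in the trace is constant over the plotted samples, so its KDE covariance is singular:

from scipy.stats import gaussian_kde
# A constant series has zero variance, so the KDE cannot invert its
# covariance matrix and raises a LinAlgError like the one above.
gaussian_kde(np.ones(100))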

Fixing the traceplot bug

We can still print out summary information for "ber" (the trace values and so on), so it's really just the plotting that fails. I'll leave the pm.summary call below commented out because its output is enormous and there's no way to truncate it; a more compact way to peek at the indicator trace is sketched after the next cell.

In [93]:
n1 = 500
n2 = 200
n = n1+n2

mean1 = 20
mean2 = 80

# precision = 1/sigma^2
precision = .05
sigma = np.sqrt(1 / precision)
print("sigma: %s" % sigma)

data1 = np.random.normal(mean1,sigma,n1)
data2 = np.random.normal(mean2,sigma,n2)
data = np.concatenate([data1 , data2])

with pm.Model() as model:
    p = pm.Uniform( "p", 0 , 1) #this is the fraction that come from mean1 vs mean2
    ber = pm.Bernoulli( "ber", p = p, shape=len(data)) # produces 1 with proportion p.

    sigma = pm.Uniform('sigma', 0, 100)
    precision = sigma**-2

    mean = pm.Normal( "mean", 0, 0.01, shape=2 )
    mu = pm.Deterministic('mu', mean[ber])
    process = pm.Normal('process', mu=mu, tau=precision, observed=data)

with model:
    step1 = pm.Metropolis([p, sigma, mean])
    step2 = pm.BinaryMetropolis([ber])
    trace = pm.sample(10000, [step1, step2])
    pm.traceplot(trace[1000:][::10], vars=["p","sigma","mean"])
    #pm.summary(trace, vars=["ber"])
sigma: 4.472135955
 [-----------------100%-----------------] 10000 of 10000 complete in 12.2 sec

It doesn't crash now, but again it totally fails to find the Gaussians. This time it thinks there's one Gaussian with mean 0 and one at 36 with a large sigma. Huh?
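
(As promised above, here's a compact way to peek at the "ber" indicators from the cell above without dumping the full summary; just a sketch.)

# Stack the kept samples of "ber": one row per sample, one column per data point.
ber_samples = np.array([t["ber"] for t in trace[1000:]])
print(ber_samples.mean(axis=0)[:10])   # per-point probability of being in component 1
print(ber_samples.sum(axis=1)[:10])    # points assigned to component 1, per sample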

Trying find_MAP

Maybe the model is just poorly initialized. Next idea: let's feed the sampler starting values from pm.find_MAP and see what the log likelihoods (logp) look like. Maybe the logp of the model is -inf, and that's what's screwing everything up. A logp of -inf can happen if the data is impossible given the model (e.g., negative values where they are not allowed).
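
Just to make that last point concrete, here's a tiny toy model (separate from the mixture, and assuming the same pm API as the cells above) where impossible observed data drives logp to -inf:

# An Exponential has support only on non-negative values, so a negative
# observation is impossible and the model logp is -inf whatever lam is.
with pm.Model() as impossible:
    lam = pm.Uniform("lam", 0, 10)
    obs = pm.Exponential("obs", lam, observed=np.array([-1.0, 2.0, 3.0]))
    print(impossible.logp({"lam": 1.0}))   # expect -inf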

In [96]:
n1 = 500
n2 = 200
n = n1+n2

mean1 = 20
mean2 = 80

# precision = 1/sigma^2
precision = .05
sigma = np.sqrt(1 / precision)
print("sigma: %s" % sigma)

data1 = np.random.normal(mean1,sigma,n1)
data2 = np.random.normal(mean2,sigma,n2)
data = np.concatenate([data1 , data2])

with pm.Model() as model:
    p = pm.Uniform( "p", 0 , 1) #this is the fraction that come from mean1 vs mean2
    ber = pm.Bernoulli( "ber", p = p, shape=len(data)) # produces 1 with proportion p.

    sigma = pm.Uniform('sigma', 0, 100)
    precision = sigma**-2

    mean = pm.Normal( "mean", 0, 0.01, shape=2 )
    mu = pm.Deterministic('mu', mean[ber])
    process = pm.Normal('process', mu=mu, tau=precision, observed=data)

with model:
    start = pm.find_MAP()
    
    print("start", [(k, ''.join(str(int(v)) for v in start[k])) for k in ("ber",)])
    print("start", [(k,start[k]) for k in start.keys() if k not in ("ber",)])
    
    print("model", model.logp(start))
    print("ber", ber.logp(start))
    print("p", p.logp(start))
    print("sigma", sigma.logp(start))
    print("mean", mean.logp(start))
    print("process", process.logp(start))
    
    step1 = pm.Metropolis([p, sigma, mean])
    step2 = pm.BinaryMetropolis([ber])
    trace = pm.sample(10000, [step1, step2])
    pm.traceplot(trace[1000:][::10], vars=["p","sigma","mean"])
sigma: 4.472135955
start [('ber', '1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111')]
start [('p', array(0.5)), ('sigma', array(50.0)), ('mean', array([ 0.,  0.]))]
model -4177.02763767
ber -485.203026392
p 0.0
sigma -4.60517024994
mean -6.4430472524
process -3680.77639378
 [-----------------100%-----------------] 10000 of 10000 complete in 11.9 sec

So find_MAP thinks that a good model is one where all the data are generated by one Gaussian with a big variance (i.e. all the datapoints are in Gaussian 1 and none in Gaussian 0). Well... that's actually not so crazy. I didn't tell the model that some fraction of the data was generated by each of the Gaussians, so why not? I note for future reference that the total logp of this single-Gaussian model is about -4200, a baseline worth keeping in mind for the runs below.

After MCMC, the trace reverts to the same bad model as we've seen before, with both Gaussians kind of in the middle with large sigmas. (Again, I tried increasing the number of samples to no effect.)

Trying good start values

This time, I'll give the sampler a great shot at winning by starting with a model close to the true model, and see how the logps compare to this MAP starting point.

In [97]:
n1 = 500
n2 = 200
n = n1+n2

mean1 = 20
mean2 = 80

# precision = 1/sigma^2
precision = .05
sigma = np.sqrt(1 / precision)
print("sigma: %s" % sigma)

data1 = np.random.normal(mean1,sigma,n1)
data2 = np.random.normal(mean2,sigma,n2)
data = np.concatenate([data1 , data2])

with pm.Model() as model:
    p = pm.Uniform( "p", 0 , 1) #this is the fraction that come from mean1 vs mean2
    ber = pm.Bernoulli( "ber", p = p, shape=len(data)) # produces 1 with proportion p.

    sigma = pm.Uniform('sigma', 0, 100)
    precision = sigma**-2

    mean = pm.Normal( "mean", 0, 0.01, shape=2 )
    mu = pm.Deterministic('mu', mean[ber])
    process = pm.Normal('process', mu=mu, tau=precision, observed=data)

with model:
    start = pm.find_MAP()
    start["ber"][0:350] = 0.
    start["ber"][350:] = 1.
    start["p"] = 0.5
    start["sigma"] = 5
    start["mean"] = np.array([30, 60])
    
    print("model", model.logp(start))
    print("ber", ber.logp(start))
    print("p", p.logp(start))
    print("sigma", sigma.logp(start))
    print("mean", mean.logp(start))
    print("process", process.logp(start))
    
    step1 = pm.Metropolis([p, sigma, mean])
    step2 = pm.BinaryMetropolis([ber])
    trace = pm.sample(10000, start=start, step=[step1, step2])
    pm.traceplot(trace[1000:][::10], vars=["p","sigma","mean"])
    print("\n", [int(model.logp(t)) for t in trace[1000:][::100]])
sigma: 4.472135955
model -9631.68143668
ber -485.203026392
p 0.0
sigma -4.60517024994
mean -28.9430472524
process -9112.93019279
 [-----------------100%-----------------] 10000 of 10000 complete in 11.3 sec
 [-3654, -3652, -3653, -3654, -3652, -3652, -3651, -3652, -3653, -3652, -3652, -3656, -3652, -3652, -3652, -3652, -3654, -3653, -3652, -3651, -3652, -3657, -3656, -3655, -3653, -3652, -3652, -3655, -3653, -3654, -3653, -3654, -3652, -3658, -3652, -3655, -3652, -3651, -3652, -3656, -3654, -3655, -3652, -3652, -3653, -3652, -3654, -3655, -3652, -3653, -3651, -3651, -3651, -3654, -3655, -3654, -3652, -3654, -3652, -3652, -3654, -3651, -3652, -3652, -3654, -3654, -3654, -3653, -3654, -3653, -3652, -3652, -3652, -3653, -3653, -3652, -3652, -3654, -3653, -3652, -3653, -3653, -3653, -3651, -3655, -3652, -3652, -3652, -3654, -3651]

This time I gave the sampler every advantage. The means, sigma and p all start at reasonable values. The logp of this model at the start is about -9600: worse than the MAP logp of about -4200, but nowhere near -inf, so the sampler has something sensible to work from. Nevertheless, the trace ends up with one cluster at 20 and one at 55. The logp, which I print out at regular intervals during sampling, hovers around -3650.

How about NUTS?

Maybe I am using a bad sampler? Metropolis is not a state-of-the-art sampler like NUTS, so I'll try that instead.

In [99]:
n1 = 500
n2 = 200
n = n1+n2

mean1 = 20
mean2 = 80

# precision = 1/sigma^2
precision = .05
sigma = np.sqrt(1 / precision)
print("sigma: %s" % sigma)

data1 = np.random.normal(mean1,sigma,n1)
data2 = np.random.normal(mean2,sigma,n2)
data = np.concatenate([data1 , data2])

with pm.Model() as model:
    p = pm.Uniform( "p", 0 , 1) #this is the fraction that come from mean1 vs mean2
    ber = pm.Bernoulli( "ber", p = p, shape=len(data)) # produces 1 with proportion p.

    sigma = pm.Uniform('sigma', 0, 100)
    precision = sigma**-2

    mean = pm.Normal( "mean", 0, 0.01, shape=2 )
    mu = pm.Deterministic('mu', mean[ber])
    process = pm.Normal('process', mu=mu, tau=precision, observed=data)

with model:
    start = pm.find_MAP()
    start["ber"][0:350] = 0.
    start["ber"][350:] = 1.
    start["p"] = 0.5
    start["sigma"] = 5
    start["mean"] = np.array([30, 60])
    
    print("model", model.logp(start))
    print("ber", ber.logp(start))
    print("p", p.logp(start))
    print("sigma", sigma.logp(start))
    print("mean", mean.logp(start))
    print("process", process.logp(start))
    
    step1 = pm.NUTS([p, sigma, mean])
    step2 = pm.BinaryMetropolis([ber])
    trace = pm.sample(10000, start=start, step=[step1, step2])
    pm.traceplot(trace[1000:][::10], vars=["p","sigma","mean"])
    #pm.summary(trace, vars=["ber"])
    print([int(model.logp(t)) for t in trace[1000:][::100]])
    print([trace[i]['ber'].sum() for i in np.arange(0,len(trace),100)])
sigma: 4.472135955
model -9444.98912448
ber -485.203026392
p 0.0
sigma -4.60517024994
mean -28.9430472524
process -8926.23788059
 [-----------------100%-----------------] 10000 of 10000 complete in 63.1 sec[-3640, -3642, -3642, -3639, -3641, -3640, -3640, -3640, -3640, -3640, -3639, -3640, -3641, -3639, -3641, -3644, -3639, -3641, -3640, -3639, -3639, -3641, -3641, -3639, -3641, -3640, -3641, -3640, -3640, -3641, -3640, -3640, -3641, -3639, -3639, -3640, -3639, -3641, -3639, -3639, -3641, -3639, -3639, -3640, -3642, -3640, -3641, -3640, -3639, -3642, -3640, -3639, -3639, -3641, -3642, -3640, -3640, -3642, -3640, -3639, -3641, -3641, -3640, -3640, -3641, -3639, -3640, -3640, -3641, -3639, -3640, -3642, -3642, -3641, -3640, -3640, -3640, -3642, -3639, -3639, -3646, -3642, -3642, -3640, -3639, -3646, -3641, -3639, -3640, -3640]
[350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0]

Even NUTS, the hottest sampler that knows all the shortcuts, fails... It's pretty close, but basically the sigma is too large, and only one of the clusters converges on the right position. That's odd, since it appears to be on the right track, and I know that moving to the correct means would give a better logp than the roughly -3640 it's stuck at.
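
One way to check that claim is to evaluate the model's logp at a hand-built point using (roughly) the values the data were generated from. This is just a logp evaluation after the cell above, not another sampling run:

# Assign data1 to component 0 and data2 to component 1, with the true
# means, sigma and mixing fraction (a sketch; reuses start from above).
true_point = dict(start)
true_point["ber"] = np.concatenate([np.zeros(n1), np.ones(n2)])
true_point["p"] = float(n2) / n            # ~0.29 of the points in component 1
true_point["sigma"] = np.sqrt(1 / .05)     # the true sigma, ~4.47
true_point["mean"] = np.array([20., 80.])
print(model.logp(true_point))              # should be clearly better than the ~-3640 above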

Fixing BinaryMetropolis sampling

But there's also a clue in here. If I print out the sums of trace["ber"] (i.e., how many data points are assigned to Gaussian 1 in each sample), the count never deviates from the initialized value of 350, even though it's supposed to be sampled with BinaryMetropolis. It really looks like BinaryMetropolis is not sampling properly. Looking at its source on GitHub shows a scaling parameter that might help: https://github.com/pymc-devs/pymc/blob/master/pymc/step_methods/metropolis.py
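
From a quick read of that source, the scaling parameter appears to set the probability of flipping each individual indicator in a proposal, roughly p_jump = 1 - 0.5**scaling. That's my interpretation of the code, not documented behaviour, but if it's right a quick back-of-the-envelope shows why the default struggles with 700 indicators:

# Back-of-the-envelope, assuming BinaryMetropolis flips each indicator with
# probability p_jump = 1 - 0.5**scaling (my reading of the source above).
for s in (1.0, 0.1, 0.01):
    p_jump = 1 - 0.5 ** s
    print("scaling=%g  p_jump=%.4f  expected flips per proposal=%.1f"
          % (s, p_jump, p_jump * n))
# With the default scaling of 1, roughly half of the 700 labels flip in every
# proposal; a move that drastic is almost never accepted, so the chain sticks.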

In [83]:
n1 = 500
n2 = 200
n = n1+n2

mean1 = 20
mean2 = 80

# precision = 1/sigma^2
precision = .05
sigma = np.sqrt(1 / precision)
print("sigma: %s" % sigma)

data1 = np.random.normal(mean1,sigma,n1)
data2 = np.random.normal(mean2,sigma,n2)
data = np.concatenate([data1 , data2])

with pm.Model() as model:
    p = pm.Uniform( "p", 0 , 1) #this is the fraction that come from mean1 vs mean2
    ber = pm.Bernoulli( "ber", p = p, shape=len(data)) # produces 1 with proportion p.

    sigma = pm.Uniform('sigma', 0, 100)
    precision = sigma**-2

    mean = pm.Normal( "mean", 0, 0.01, shape=2 )
    mu = pm.Deterministic('mu', mean[ber])
    process = pm.Normal('process', mu=mu, tau=precision, observed=data)

with model:
    start = pm.find_MAP()
    start["ber"][0:350] = 0.
    start["ber"][350:] = 1.
    start["p"] = 0.5
    start["sigma"] = 5
    start["mean"] = np.array([30, 60])
    print("model", model.logp(start))
    print("ber", ber.logp(start))
    print("p", p.logp(start))
    print("sigma", sigma.logp(start))
    print("mean", mean.logp(start))
    print("process", process.logp(start))
    
    step1 = pm.Slice([p, sigma, mean])
    step2 = pm.BinaryMetropolis([ber], scaling=.01)
    trace = pm.sample(40000, start=start, step=[step1, step2])
    pm.traceplot(trace[20000:][::50], vars=["p","sigma","mean"])
    print([int(model.logp(t)) for t in trace[20000:][::400]])
    print([trace[i]['ber'].sum() for i in np.arange(0,len(trace),400)])
sigma: 4.472135955
model -9422.32659179
ber -485.203026392
p 0.0
sigma -4.60517024994
mean -28.9430472524
process -8903.57534789
 [-----------------100%-----------------] 40001 of 40000 complete in 1153.7 sec[-2937, -2933, -2932, -2932, -2907, -2899, -2876, -2852, -2819, -2819, -2822, -2820, -2793, -2794, -2791, -2754, -2715, -2716, -2719, -2717, -2715, -2716, -2701, -2702, -2700, -2701, -2702, -2700, -2701, -2700, -2691, -2688, -2689, -2657, -2657, -2658, -2661, -2655, -2655, -2656, -2655, -2655, -2605, -2605, -2605, -2605, -2605, -2605, -2552, -2554]
[350.0, 328.0, 310.0, 284.0, 272.0, 278.0, 257.0, 245.0, 247.0, 248.0, 243.0, 241.0, 235.0, 234.0, 227.0, 227.0, 224.0, 225.0, 225.0, 220.0, 221.0, 218.0, 215.0, 216.0, 213.0, 213.0, 212.0, 212.0, 210.0, 210.0, 210.0, 210.0, 209.0, 211.0, 213.0, 213.0, 214.0, 213.0, 213.0, 212.0, 209.0, 209.0, 209.0, 209.0, 205.0, 205.0, 205.0, 203.0, 204.0, 204.0, 203.0, 201.0, 201.0, 199.0, 198.0, 198.0, 199.0, 198.0, 199.0, 199.0, 199.0, 199.0, 198.0, 198.0, 198.0, 199.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 203.0, 203.0, 203.0, 203.0, 203.0, 203.0, 203.0, 203.0, 203.0, 202.0, 202.0, 202.0, 202.0, 202.0, 202.0, 201.0, 201.0]

Finally, after setting BinaryMetropolis' scaling parameter to 0.01, the model converges on the right answer. (I also switched the continuous variables to a Slice sampler and upped the number of samples to 40000 to get better convergence.) You can see the sigma continuing to drop right up to the end as the clusters become better defined.

The problem was the BinaryMetropolis sampler. With a smaller scaling factor, each proposal changes far fewer of the indicator variables at once, so many more proposals are accepted and the sampler can actually explore the space. It's possible that in general the scaling needs to take the size of the data into account -- in this case, 700 indicator values -- otherwise nearly every proposal gets rejected.
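
If that reading of the source is right, a rough heuristic (my own, not an official guideline) is to pick scaling so that only a handful of indicators flip per proposal:

# Heuristic sketch, assuming p_jump = 1 - 0.5**scaling as above: aim for
# about k flips per proposal among n binary indicators.
k = 5.0
scaling = np.log2(1.0 / (1.0 - k / n))   # solves 1 - 0.5**scaling = k / n
print(scaling)                           # roughly 0.01 for n = 700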

It's interesting to compare this approach with a simple EM strategy, which is fast, robust and in my experience, just works. If I were writing code that unmixed Gaussians I would certainly stick with EM. MCMC really shines with more complex models, especially hierarchical models.
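
For comparison, here's a minimal EM sketch for the same two-Gaussian, shared-sigma problem in plain numpy. This is not PyMC and not code I'd ship; it's just the standard EM updates written out to show how little machinery the problem actually needs:

# Minimal EM for a 1-D mixture of two Gaussians with a shared sigma.
def em_two_gaussians(x, n_iter=100):
    def norm_pdf(v, m, s):
        return np.exp(-0.5 * ((v - m) / s) ** 2) / (s * np.sqrt(2 * np.pi))
    mu = np.array([np.percentile(x, 25), np.percentile(x, 75)], dtype=float)
    sigma, pi1 = x.std(), 0.5            # crude initial guesses
    for _ in range(n_iter):
        # E-step: responsibility of component 1 for each point
        r1 = pi1 * norm_pdf(x, mu[1], sigma)
        r0 = (1 - pi1) * norm_pdf(x, mu[0], sigma)
        resp = r1 / (r0 + r1)
        # M-step: re-estimate the weight, the means and the shared sigma
        pi1 = resp.mean()
        mu[0] = np.sum((1 - resp) * x) / np.sum(1 - resp)
        mu[1] = np.sum(resp * x) / np.sum(resp)
        sigma = np.sqrt(np.sum((1 - resp) * (x - mu[0]) ** 2
                               + resp * (x - mu[1]) ** 2) / len(x))
    return mu, sigma, pi1

# On the data above this should land near means 20/80, sigma ~4.5, p1 ~0.29.
print(em_two_gaussians(data))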