archives –––––––– @btnaughton

This post is a numpyro-based probabilistic model of my blood glucose data (see my original blogpost for necessary background). The task is: given my blood glucose readings, my heart rate data, and what I ate, can I figure out the contribution of each food to my blood glucose?

This is a classic Bayesian problem, where I have some observed data (blood glucose readings), a model for how food and heart rate influence blood glucose, and I want to find likely values for my parameters.

Normally I would use pymc3 for this kind of thing, but this was a good excuse to try out numpyro, which is built on JAX, the current favorite for automatic differentiation.

To make the problem a little more concrete, say your fasting blood glucose is 80, then you eat something and an hour later it goes to 120, and an hour after that it's back down to 90. How much blood glucose can I attribute to that food? What if I eat the same food on another day, but this time my blood glucose shoots to 150? What if I eat something else 30 minutes after eating this food so the data overlap? In theory, all of these situations can be modeled in a fairly simple probabilistic model.

Imports

I am using numpyro, and jax numpy instead of regular numpy. In most cases, jax is a direct replacement for numpy, but sometimes I need original numpy, so I keep that as onp.

#!pip install numpyro
import jax.numpy as np
from jax import random
from jax.scipy.special import logsumexp

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import hpdi
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

import numpy as onp
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

import glob
import json
import io

from tqdm.autonotebook import tqdm
from contextlib import redirect_stdout

%config InlineBackend.figure_format='retina'

Some simple test data

Simple test data is always a good idea for any kind of probabilistic model. If I can't make the model work with the simplest possible data, then it's best to figure that out early, before starting to try to use real data. I believe McElreath advocates for something similar, though I can't find a citation for this.

To run inference with my model I need three 1D arrays:

  • bg_5min: blood glucose readings, measured every 5 minutes
  • hr_5min: heart rate, measured every 5 minutes
  • fd_5min: food eaten, measured every 5 minutes
def str_to_enum(arr):
    foods = sorted(set(arr))
    assert foods[0] == '', f"{foods[0]}"
    mapk = {food:n for n, food in enumerate(foods)}
    return np.array([mapk[val] for val in arr]), foods

bg_5min_test = np.array([91, 89, 92, 90, 90, 90, 88, 90, 93, 90,
                         90, 100, 108, 114.4, 119.52, 123.616, 126.89, 119.51, 113.61, 108.88, 105.11,
                         102.089, 99.67, 97.73, 96.18, 94.95, 93.96, 93.16, 92.53, 92.02, 91.62, 91.29, 91.03, 90.83, 90.66, 90.53,
                         90, 90, 90, 90, 90, 90, 90, 89, 90, 90, 89, 90, 90, 90, 90, 90, 90, 90, 90, 95, 95, 95, 95, 100])

_fd_5min_test = ['', '', '', '', '', 'food', '', '', '', '', '', '',
                 '', '', '', '', '', '', '', '', '', '', '', '',
                 '', '', '', '', '', '', '', '', '', '', '', '',
                 '', '', '', '', 'doof', '', '', '', '', '', '', '',
                 '', '', '', '', '', '', '', '', '', '', '', '', '']

fd_5min_test, fd_5min_test_key = str_to_enum(_fd_5min_test)
hr_5min_test = np.array([100] * len(bg_5min_test))

bg_5min_test = np.tile(bg_5min_test, 1)
fd_5min_test = np.tile(fd_5min_test, 1)
hr_5min_test = np.tile(hr_5min_test, 1)

print(f"Example test data (len {len(bg_5min_test)}):")
print("bg", bg_5min_test[:6])
print("heart rate", hr_5min_test[:6])
print("food as enum", fd_5min_test[:6])

assert len(bg_5min_test) == len(hr_5min_test) == len(fd_5min_test)
Example test data (len 60):
bg [91. 89. 92. 90. 90. 90.]
heart rate [100 100 100 100 100 100]
food as enum [0 0 0 0 0 2]

Read in real data as a DataFrame

I have already converted all my Apple Health data to csv. Here I process the data to include naive day and date (local time) — timezones are impossible otherwise.

df_complete = pd.DataFrame()

for f in glob.glob("HK*.csv"):
    _df = (pd.read_csv(f, sep=';', skiprows=1)
             .assign(date = lambda df: pd.to_datetime(df['startdate'].apply(lambda x:x[:-6]),
                                                      infer_datetime_format=True))
             .assign(day = lambda df: df['date'].dt.date,
                     time = lambda df: df['date'].dt.time)
          )

    if 'unit' not in _df:
        _df['unit'] = _df['type']
    df_complete = pd.concat([df_complete, _df[['type', 'sourcename', 'unit', 'day', 'time', 'date', 'value']]])

# clean up the names a bit and sort
df_complete = (df_complete.assign(type = lambda df: df['type'].str.split('TypeIdentifier', expand=True)[1])
                          .sort_values(['type', 'date']))
df_complete.sample(4)

I can remove a bunch of data from the complete DataFrame.

df = (df_complete
        .loc[lambda r: ~r.type.isin({"HeadphoneAudioExposure", "BodyMass", "UVExposure", "SleepAnalysis"})]
        .loc[lambda r: ~r.sourcename.isin({"Brian Naughton’s iPhone", "Strava"})]
        .loc[lambda r: r.date >= min(df_complete.loc[lambda r: r.sourcename == 'Connect'].date)]
        .assign(value = lambda df: df.value.astype(float)))
df.groupby(['type', 'sourcename']).head(1)
type                    sourcename  unit       day         time      date                    value
ActiveEnergyBurned      Connect     kcal       2018-07-08  21:00:00  2018-07-08 21:00:00      8.00
BasalEnergyBurned       Connect     kcal       2018-07-08  21:00:00  2018-07-08 21:00:00   2053.00
BloodGlucose            Dexcom G6   mg/dL      2020-02-24  22:58:54  2020-02-24 22:58:54    111.00
DistanceWalkingRunning  Connect     mi         2018-07-08  21:00:00  2018-07-08 21:00:00     0.027
FlightsClimbed          Connect     count      2018-07-08  21:00:00  2018-07-08 21:00:00      0.00
HeartRate               Connect     count/min  2018-07-01  18:56:00  2018-07-01 18:56:00     60.00
RestingHeartRate        Connect     count/min  2019-09-07  00:00:00  2019-09-07 00:00:00     45.00
StepCount               Connect     count      2018-07-01  19:30:00  2018-07-01 19:30:00     33.00

I have to process my food data separately, since it doesn't come from Apple Health. The information was just in a Google Sheet.

df_food = (pd.read_csv("food_data.csv", sep='\t')
             .astype(dtype = {"date":"datetime64[ns]", "food": str})
             .assign(food = lambda df: df["food"].str.split(',', expand=True)[0] )
             .loc[lambda r: ~r.food.isin({"chocolate clusters", "cheese weirds", "skinny almonds", "asparagus"})]
             .sort_values("date")
            )
df_food.sample(4)
date                 food
2020-03-03 14:10:00  egg spinach on brioche
2020-03-11 18:52:00  lentil soup and bread
2020-02-27 06:10:00  bulletproof coffee
2020-03-02 06:09:00  bulletproof coffee

For inference purposes, I am rounding everything to the nearest 5 minutes. This creates the three arrays the model needs.

_df_food = (df_food
              .assign(rounded_date = lambda df: df['date'].dt.round('5min'))
              .set_index('rounded_date'))

_df_hr = (df.loc[lambda r: r.type == 'HeartRate']
            .assign(rounded_date = lambda df: df['date'].dt.round('5min'))
            .groupby('rounded_date')
            .mean())

def get_food_at_datetime(datetime):
    food = (_df_food
              .loc[datetime - pd.Timedelta(minutes=1) : datetime + pd.Timedelta(minutes=1)]
              ['food'])

    if any(food):
        return sorted(set(food))[0]
    else:
        return None

def get_hr_at_datetime(datetime):
    hr = (_df_hr
            .loc[datetime - pd.Timedelta(minutes=1) : datetime + pd.Timedelta(minutes=1)]
            ['value'])
    if any(hr):
        return sorted(set(hr))[0]
    else:
        return None


df_5min = \
(df.loc[lambda r: r.type == "BloodGlucose"]
   .loc[lambda r: r.date >= pd.Timestamp(2020, 2, 25)]  # pd.datetime was removed from recent pandas
   .assign(rounded_date = lambda df: df['date'].dt.round('5min'),
           day = lambda df: df['date'].dt.date,
           time = lambda df: df['date'].dt.round('5min').dt.time,
           food = lambda df: df['rounded_date'].apply(get_food_at_datetime),
           hr = lambda df: df['rounded_date'].apply(get_hr_at_datetime)
          )
)


bg_5min_real = onp.array(df_5min.pivot_table(index='day', columns='time', values='value').interpolate('linear'))
hr_5min_real = onp.array(df_5min.pivot_table(index='day', columns='time', values='hr').interpolate('linear'))
fd_5min_real = onp.array(df_5min.fillna('').pivot_table(index='day', columns='time', values='food', aggfunc=lambda x:x).fillna('')).astype(str)
assert bg_5min_real.shape == hr_5min_real.shape == fd_5min_real.shape, "array size mismatch"

print(f"Real data (len {bg_5min_real.shape}):")
print("bg", bg_5min_real[0][:6])
print("hr", hr_5min_real[0][:6])
print("food", fd_5min_real[0][:6])
Real data (len (30, 288)):
bg [102. 112. 114. 113. 112. 111.]
hr [50.33333333 42.5        42.66666667 43.         42.66666667 42.5       ]
food ['' '' '' '' '' '']
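
The rounding step is what maps every event onto one of the 288 five-minute slots in a day. Here is a minimal standard-library sketch of the same idea, using two timestamps from the food table above. The helpers `round_to_5min` and `slot_of_day` are hypothetical names that are not in my notebook, and exact ties may be broken differently than pandas breaks them.

```python
from datetime import datetime, timedelta

def round_to_5min(ts):
    """Round a naive datetime to the nearest 5-minute boundary (ties go up)."""
    seconds = ts.hour * 3600 + ts.minute * 60 + ts.second
    slot = int(seconds / 300 + 0.5)  # 300 seconds = 5 minutes
    midnight = ts.replace(hour=0, minute=0, second=0, microsecond=0)
    return midnight + timedelta(minutes=5 * slot)

def slot_of_day(ts):
    """Index of the rounded time within the day, 0..287."""
    rounded = round_to_5min(ts)
    return (rounded.hour * 60 + rounded.minute) // 5

coffee = datetime(2020, 3, 2, 6, 9)     # bulletproof coffee
soup = datetime(2020, 3, 11, 18, 52)    # lentil soup and bread

for ts in (coffee, soup):
    print(ts, "->", round_to_5min(ts), "slot", slot_of_day(ts))
```

A food eaten at 06:09 therefore lands in the same slot as the glucose reading rounded to 06:10, which is what lets the three arrays line up index by index.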

numpyro model

The model is necessarily pretty simple, given the amount of data I have to work with. Note that I use "g" as a shorthand for mg/dL, which may be confusing. The priors are:

  • baseline_mu: I have a uniform prior on my "baseline" (fasting) mean blood glucose level. If I haven't eaten, I expect my blood glucose to be distributed around this value. This is a uniform prior from 80–110 mg/dL.
  • all_food_g_5min_mu: Each food deposits a certain number of mg/dL of sugar into my blood every 5 minutes. This is a uniform prior from 8–16 mg/dL per 5 minutes.
  • all_food_delay_mu: Each food only starts depositing glucose after a delay for digestion and absorption. This is a uniform prior from 45–100 minutes.
  • all_food_duration_mu: Food keeps depositing sugar into my blood for some time. Different foods may be quick spikes of glucose (sugary) or slow release (fatty). This is a uniform prior from 25–50 minutes.
  • regression_rate: The body regulates glucose by releasing insulin, which pulls blood glucose back to baseline at a certain rate. I allow for three different rates, depending on my heart rate zone. Using this simple model, the higher the blood glucose, the more it regresses, so it produces a somewhat bell-shaped curve.
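
To get a feel for what these nested uniform priors allow, here is a rough standard-library sketch that draws per-food parameters the same way the model code does: a group-level mean first, then a per-food value in a window around it. The bounds are copied from the model code, where delay and duration are in 5-minute steps (so Uniform(9, 20) is the 45–100 minute delay prior). The function name `draw_food_prior` is my own for this illustration, and this is plain prior sampling, not numpyro.

```python
import random

STEP_MIN = 5  # one model time step is 5 minutes

def draw_food_prior(rng):
    """Draw one food's parameters from the two-level uniform priors."""
    delay_mu = rng.uniform(9, 20)      # group mean delay, in 5-minute steps
    duration_mu = rng.uniform(5, 10)   # group mean duration, in 5-minute steps
    g_mu = rng.uniform(8, 16)          # group mean mg/dL deposited per step
    return {
        "delay_min": rng.uniform(delay_mu - 4, delay_mu + 4) * STEP_MIN,
        "duration_min": rng.uniform(duration_mu - 3, duration_mu + 3) * STEP_MIN,
        "g_5min": rng.uniform(g_mu - 8, g_mu + 8),
    }

rng = random.Random(0)
draws = [draw_food_prior(rng) for _ in range(1000)]

for key in ("delay_min", "duration_min", "g_5min"):
    values = [d[key] for d in draws]
    print(key, round(min(values), 1), round(max(values), 1))
```

The per-food windows widen the effective range quite a bit: an individual food's delay can fall anywhere from 25 to 120 minutes even though the group mean is restricted to 45–100.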

Using uniform priors is really not best practice, but it makes it easier to keep values in a reasonable range. I did experiment with other priors, but settled on this after failing to keep my posteriors to reasonable values. As usual, explaining the priors to the model is like asking for a wish from an obtuse genie. This is a common problem with probabilistic models, at least for me.

The observed blood glucose is then a function that uses this approximate logic:

for t in all_5_minute_increments:
  blood_glucose[t] = baseline[t] # blood glucose with no food
  for food in foods:
    if t - time_eaten[food] > food_delays[food]: # then the food has started entering the bloodstream
      if t - (time_eaten[food] + food_delays[food]) < food_duration[food]: # then food is still entering the bloodstream
        blood_glucose[t] += food_g_5min[food]

  blood_glucose[t] = blood_glucose[t] - (blood_glucose[t] - baseline[t]) * regression_rate[heart_rate_zone_at_time(t)]
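
The logic above can be written as a small runnable forward simulation. This is a sketch under some assumptions: it follows the convention of the model code, where the excess over baseline is multiplied by the regression rate at every step, it uses a single rate instead of one per heart rate zone, and `simulate_bg` is a hypothetical helper name. The example parameters (eaten at step 5, delay 5, duration 6, 10 mg/dL per step, rate 0.8, baseline 90) are ones I picked because they appear to reproduce the first spike in the test data.

```python
def simulate_bg(n_steps, baseline, meals, rate):
    """Forward-simulate blood glucose on a 5-minute grid.

    meals: list of (step_eaten, delay, duration, g_per_step), all in steps.
    rate:  fraction of the excess over baseline that survives each step.
    """
    excess = [0.0] * n_steps
    for t in range(1, n_steps):
        deposit = 0.0
        for step_eaten, delay, duration, g_per_step in meals:
            since = t - step_eaten
            # the food deposits glucose once the delay has passed,
            # and keeps doing so for `duration` steps
            if delay < since <= delay + duration:
                deposit += g_per_step
        excess[t] = excess[t - 1] * rate + deposit
    return [baseline + e for e in excess]

trace = simulate_bg(36, 90.0, [(5, 5, 6, 10.0)], 0.8)
print([round(x, 2) for x in trace[10:19]])
```

Because the amount removed at each step is proportional to the excess, the curve rises quickly, flattens as it approaches the level where deposits and regression balance, and then decays geometrically once the food stops depositing.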

This model has some nice properties:

  • every food has only three parameters: the delay, the duration, and the mg/dL per 5 minutes. The total amount of glucose (loosely defined) can be easily calculated by multiplying food_duration * food_g_5min.
  • the model only has two other parameters: the baseline glucose, and regression rate
MAX_TIMEPOINTS = 70 # effects are modeled 350 mins into the future only
DL = 288 # * 5 minutes = 1 day
SUB = '' # or a number, to subset
DAYS = [11]

%%time
def model(bg_5min, hr_5min, fd_5min, fd_5min_key):

    def zone(hr):
        if hr < 110: return "1"
        elif hr < 140: return "2"
        else: return "3"

    assert fd_5min_key[0] == '', f"{fd_5min_key[0]}"
    all_foods = fd_5min_key[1:]

    # ----------------------------------------------

    baseline_mu = numpyro.sample('baseline_mu', dist.Uniform(80, 110))

    all_food_g_5min_mu = numpyro.sample("all_food_g_5min_mu", dist.Uniform(8, 16))
    food_g_5mins = {food : numpyro.sample(f"food_g_5min_{food}",
                                          dist.Uniform(all_food_g_5min_mu-8, all_food_g_5min_mu+8))
                    for food in all_foods}

    all_food_duration_mu = numpyro.sample(f"all_food_duration_mu", dist.Uniform(5, 10))
    food_durations = {food : numpyro.sample(f"food_duration_{food}",
                                            dist.Uniform(all_food_duration_mu-3, all_food_duration_mu+3))
                      for food in all_foods}

    regression_rate = {zone : numpyro.sample(f"regression_rate_{zone}", dist.Uniform(0.7, 0.96))
                       for zone in "123"}

    all_food_delay_mu = numpyro.sample("all_food_delay_mu", dist.Uniform(9, 20))
    food_delays = {food : numpyro.sample(f"food_delay_{food}", dist.Uniform(all_food_delay_mu-4, all_food_delay_mu+4))
                   for food in all_foods}

    food_g_totals = {food : numpyro.deterministic(f"food_g_total_{food}", food_durations[food] * food_g_5mins[food])
                     for food in all_foods}

    result_mu = [0 for _ in range(len(bg_5min))]
    for j in range(1, len(result_mu)):
        if fd_5min[j] == 0:
            continue

        food = fd_5min_key[fd_5min[j]]
        for _j in range(min(MAX_TIMEPOINTS, len(result_mu)-j)):
            result_mu[j+_j] = numpyro.deterministic(
                f"add_{food}_{j}_{_j}",
                (result_mu[j+_j-1] * regression_rate[zone(hr_5min[j+_j])]
                 + np.where(_j > food_delays[food],
                            np.where(food_delays[food] + food_durations[food] - _j < 0, 0, 1) * food_g_5mins[food],
                            0)
                )
            )


    for j in range(len(result_mu)):
        result_mu[j] = numpyro.deterministic(f"result_mu_{j}", baseline_mu + result_mu[j])

    # observations
    obs = [numpyro.sample(f'result_{i}', dist.Normal(result_mu[i], 5), obs=bg_5min[i])
           for i in range(len(result_mu))]

def make_model(bg_5min, hr_5min, fd_5min, fd_5min_key):
    from functools import partial
    return partial(model, bg_5min, hr_5min, fd_5min, fd_5min_key)

Sample and store results

Because the process is slow, I usually only run one day at a time, and store all the samples in json files. It would be much better to run inference on all days simultaneously, especially since I have data for the same food on different days. Unfortunately, this turned out to take way too long.

assert len(bg_5min_real.ravel())/DL == 30, len(bg_5min_real.ravel())/DL

for d in DAYS:
    print(f"day {d}")
    bg_5min = bg_5min_real.ravel()[int(DL*d):int(DL*(d+1))]
    hr_5min = hr_5min_real.ravel()[int(DL*d):int(DL*(d+1))]
    _fd_5min = fd_5min_real.ravel()[int(DL*d):int(DL*(d+1))]
    fd_5min, fd_5min_key = str_to_enum(_fd_5min)

    #
    # subset it for testing
    #
    if SUB != '':
        bg_5min, hr_5min, fd_5min = bg_5min[-SUB*3:-SUB*1], hr_5min[-SUB*3:-SUB*1], fd_5min[-SUB*3:-SUB*1]

    open(f"bg_5min_{d}{SUB}.json",'w').write(json.dumps(list(bg_5min)))
    open(f"hr_5min_{d}{SUB}.json",'w').write(json.dumps(list(hr_5min)))
    open(f"fd_5min_key_{d}{SUB}.json",'w').write(json.dumps(fd_5min_key))
    open(f"fd_5min_{d}{SUB}.json",'w').write(json.dumps([int(ix) for ix in fd_5min]))

    rng_key = random.PRNGKey(0)
    rng_key, rng_key_ = random.split(rng_key)
    num_warmup, num_samples = 500, 5500

    kernel = NUTS(make_model(bg_5min, hr_5min, fd_5min, fd_5min_key))
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)  # keywords: required by newer numpyro
    mcmc.run(rng_key_)
    mcmc.print_summary()
    samples_1 = mcmc.get_samples()
    print(d, sorted([(float(samples_1[food_type].mean()), food_type) for food_type in samples_1 if "total" in food_type], reverse=True))
    print(d, sorted([(float(samples_1[food_type].mean()), food_type) for food_type in samples_1 if "delay" in food_type], reverse=True))
    print(d, sorted([(float(samples_1[food_type].mean()), food_type) for food_type in samples_1 if "duration" in food_type], reverse=True))

    open(f"samples_{d}{SUB}.json", 'w').write(json.dumps({k:[round(float(v),4) for v in vs] for k,vs in samples_1.items()}))

    print_io = io.StringIO()
    with redirect_stdout(print_io):
        mcmc.print_summary()
    open(f"summary_{d}{SUB}.json", 'w').write(print_io.getvalue())
day 11


sample: 100%|██████████| 6000/6000 [48:58<00:00,  2.04it/s, 1023 steps of size 3.29e-04. acc. prob=0.81]



                                             mean       std    median      5.0%     95.0%     n_eff     r_hat
                      all_food_delay_mu     10.39      0.29     10.40      9.95     10.86     10.32      1.04
                   all_food_duration_mu      9.22      0.20      9.22      8.93      9.51      2.77      2.28
                     all_food_g_5min_mu      9.81      0.37      9.78      9.16     10.32      5.55      1.44
                            baseline_mu     89.49      0.37     89.47     88.89     90.09      9.89      1.00
          food_delay_bulletproof coffee      7.59      0.86      7.51      6.37      8.98     23.44      1.02
      food_delay_egg spinach on brioche     13.75      0.41     13.92     13.04     14.26      4.81      1.08
              food_delay_milk chocolate      6.70      0.21      6.74      6.38      7.00     17.39      1.00
     food_delay_vegetables and potatoes     12.03      0.93     12.15     10.52     13.44     16.36      1.08
       food_duration_bulletproof coffee      9.92      1.36      9.72      8.27     12.09      2.72      2.68
   food_duration_egg spinach on brioche     11.42      0.34     11.42     10.94     11.93      2.51      2.59
           food_duration_milk chocolate      7.01      0.22      6.98      6.58      7.34     16.01      1.00
  food_duration_vegetables and potatoes      8.56      1.40      8.52      6.55     10.43      4.89      1.37
         food_g_5min_bulletproof coffee      2.45      0.38      2.41      1.86      3.01      7.07      1.58
     food_g_5min_egg spinach on brioche      7.38      0.33      7.39      6.86      7.92      9.65      1.36
             food_g_5min_milk chocolate     11.38      0.38     11.40     10.70     11.94     21.69      1.06
    food_g_5min_vegetables and potatoes     14.46      2.99     15.55      9.90     18.17     10.31      1.17
                      regression_rate_1      0.93      0.00      0.93      0.92      0.93     24.68      1.04
                      regression_rate_2      0.74      0.03      0.73      0.71      0.78      3.68      1.69
                      regression_rate_3      0.78      0.03      0.77      0.73      0.82      4.28      1.60

Plot results

Plotting the data in a sensible way is a little bit tricky. Hopefully the plot is mostly self-explanatory, but note that baseline is set to zero, the y-axis is on a log scale, and the x-axis is midnight to midnight, in 5 minute increments.

def get_plot_sums(samples, bg_5min, hr_5min, fd_5min_key):
    samples = {k: onp.array(v) for k,v in samples.items()}
    bg_5min = onp.array(bg_5min)
    hr_5min = onp.array(hr_5min)

    means = []
    for n in tqdm(range(len(bg_5min))):
        means.append({"baseline": np.mean(samples[f"baseline_mu"])})
        for _j in range(MAX_TIMEPOINTS):
            for food in fd_5min_key:
                if f"add_{food}_{n-_j}_{_j}" in samples:
                    means[n][food] = np.mean(samples[f"add_{food}_{n-_j}_{_j}"])

    plot_data = []
    ordered_foods = ['baseline'] + sorted({k for d in means for k, v in d.items() if k != 'baseline'})
    for bg, d in zip(bg_5min, means):
        tot = 0
        plot_data.append([])
        for food in ordered_foods:
            if food in d and d[food] > 0.1:
                plot_data[-1].append(round(tot + float(d[food]), 2))
                if food != 'baseline':
                    tot += float(d[food])
            else:
                plot_data[-1].append(0)
    print("ordered_foods", ordered_foods)
    return pd.DataFrame(plot_data, columns=ordered_foods)
for n in DAYS:
    print(open(f"summary_{n}{SUB}.json").read())
    samples = json.load(open(f"samples_{n}{SUB}.json"))
    bg_5min = json.load(open(f"bg_5min_{n}{SUB}.json"))
    hr_5min = json.load(open(f"hr_5min_{n}{SUB}.json"))
    fd_5min = json.load(open(f"fd_5min_{n}{SUB}.json"))
    fd_5min_key = json.load(open(f"fd_5min_key_{n}{SUB}.json"))

    print("fd_5min_key", fd_5min_key)
    plot_data = get_plot_sums(samples, bg_5min, hr_5min, fd_5min_key)
    plot_data['real'] = bg_5min
    plot_data['heartrate'] = [hr for hr in hr_5min]

    baseline = {int(i) for i in plot_data['baseline']}
    assert len(baseline) == 1
    baseline = list(baseline)[0]
    log_plot_data = plot_data.drop('baseline', axis=1)
    log_plot_data['real'] -= plot_data['baseline']

    display(log_plot_data);

    f, ax = plt.subplots(figsize=(16,6));
    ax = sns.lineplot(data=log_plot_data.rename(columns={"heartrate":"hr"})
                                        .drop('hr', axis=1),
                      dashes=False,
                      ax=ax);
    ax.set_ylim(-50, 120);
    ax.set_title(f"day {n} baseline {baseline}");
    ax.lines[len(log_plot_data.columns)-2].set_linestyle("--");
    f.savefig(f'food_model_{n}{SUB}.png');
    f.savefig(f'food_model_{n}{SUB}.svg');

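[Figure: food_model_11.png, the per-food contributions plotted against measured blood glucose for day 11]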

Total blood glucose

Finally, I can rank each food by its total blood glucose contribution. Usually, this matches expectations, and occasionally it's way off. Here it looks pretty reasonable: the largest meal by far contributed the largest total amount, and the milk chocolate had the shortest delay.

from pprint import pprint
print("### total blood glucose contribution")
pprint(sorted([(round(float(samples_1[food_type].mean()), 2), food_type) for food_type in samples_1 if "total" in food_type], reverse=True))
print("### food delays")
pprint(sorted([(round(float(samples_1[food_type].mean()), 2), food_type) for food_type in samples_1 if "delay" in food_type], reverse=True))
print("### food durations")
pprint(sorted([(round(float(samples_1[food_type].mean()), 2), food_type) for food_type in samples_1 if "duration" in food_type], reverse=True))
### total blood glucose contribution
[(123.52, 'food_g_total_vegetables and potatoes'),
 (84.24, 'food_g_total_egg spinach on brioche'),
 (79.69, 'food_g_total_milk chocolate'),
 (23.88, 'food_g_total_bulletproof coffee')]
### food delays
[(13.75, 'food_delay_egg spinach on brioche'),
 (12.03, 'food_delay_vegetables and potatoes'),
 (10.39, 'all_food_delay_mu'),
 (7.59, 'food_delay_bulletproof coffee'),
 (6.7, 'food_delay_milk chocolate')]
### food durations
[(11.42, 'food_duration_egg spinach on brioche'),
 (9.92, 'food_duration_bulletproof coffee'),
 (9.22, 'all_food_duration_mu'),
 (8.56, 'food_duration_vegetables and potatoes'),
 (7.01, 'food_duration_milk chocolate')]
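
As a sanity check, the totals should be close to duration times mg/dL per 5 minutes. The sketch below multiplies the posterior means from the summary table and compares them to the reported totals. The match is not exact, because the total is computed per sample and then averaged, and the mean of a product only equals the product of the means when the two quantities are uncorrelated. The 3% tolerance is an arbitrary choice for this illustration.

```python
# (mean duration in 5-minute steps, mean mg/dL per step, reported mean total)
posterior_means = {
    "vegetables and potatoes": (8.56, 14.46, 123.52),
    "egg spinach on brioche": (11.42, 7.38, 84.24),
    "milk chocolate": (7.01, 11.38, 79.69),
    "bulletproof coffee": (9.92, 2.45, 23.88),
}

rel_diffs = {}
for food, (duration, g_5min, reported_total) in posterior_means.items():
    product = duration * g_5min
    rel_diffs[food] = abs(product - reported_total) / reported_total
    print(f"{food:26s} product {product:7.2f}  reported {reported_total:7.2f}")

all_close = all(diff < 0.03 for diff in rel_diffs.values())
print("all within 3%:", all_close)
```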

Conclusions

It sort of works. For example, day 11 looks pretty good! Unfortunately, when it's wrong, it can be pretty far off, so I am not sure about the utility without imbuing the model with more robustness. (For example, note that the r_hat is often out of range.)
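
To make the r_hat point concrete, here is a small sketch that flags poorly mixed parameters, using a few rows copied from the day 11 summary above. The 1.1 cutoff is a commonly used rule of thumb and is my assumption here, not something the model or numpyro enforces.

```python
R_HAT_THRESHOLD = 1.1   # assumed rule-of-thumb cutoff
NUM_SAMPLES = 5500      # post-warmup samples in the run above

# name: (n_eff, r_hat), copied from the printed summary
summary_rows = {
    "all_food_delay_mu": (10.32, 1.04),
    "all_food_duration_mu": (2.77, 2.28),
    "all_food_g_5min_mu": (5.55, 1.44),
    "baseline_mu": (9.89, 1.00),
    "food_duration_milk chocolate": (16.01, 1.00),
    "regression_rate_1": (24.68, 1.04),
    "regression_rate_2": (3.68, 1.69),
    "regression_rate_3": (4.28, 1.60),
}

flagged = sorted(name for name, (n_eff, r_hat) in summary_rows.items()
                 if r_hat > R_HAT_THRESHOLD)
worst_n_eff = min(n_eff for n_eff, _ in summary_rows.values())

print("r_hat above threshold:", flagged)
print(f"smallest n_eff: {worst_n_eff} out of {NUM_SAMPLES} samples")
```

Even the parameters that pass the r_hat check have an effective sample size in the tens out of 5500 draws, so the posterior means should be read as rough estimates.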

The parameterization of the model is pretty unsatisfying. To make the model hang together I have to constrain the priors to specific ranges. I could probably grid-search the whole space pretty quickly to get to a maximum posterior here.
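
Here is a rough sketch of what that grid search could look like for a single food, run on the first 36 points of the test data from earlier. It is only an illustration under several assumptions: the baseline is fixed at 90, there is one regression rate, the grid values are ones I chose by hand, and `simulate` and `log_likelihood` are hypothetical helpers. With uniform priors, the maximum posterior inside the prior bounds is just the maximum likelihood, so the search maximizes a Gaussian log-likelihood with the same standard deviation of 5 used in the model.

```python
import itertools

# First 36 points of the hand-built test series (food eaten at index 5).
bg = [91, 89, 92, 90, 90, 90, 88, 90, 93, 90,
      90, 100, 108, 114.4, 119.52, 123.616, 126.89, 119.51, 113.61, 108.88,
      105.11, 102.089, 99.67, 97.73, 96.18, 94.95, 93.96, 93.16, 92.53,
      92.02, 91.62, 91.29, 91.03, 90.83, 90.66, 90.53]
FOOD_STEP = 5
BASELINE = 90.0   # assumed known, to keep the grid small
SIGMA = 5.0       # observation noise, as in the model

def simulate(n_steps, delay, duration, g_per_step, rate):
    """Predicted blood glucose for one food eaten at FOOD_STEP."""
    excess = [0.0] * n_steps
    for t in range(1, n_steps):
        since = t - FOOD_STEP
        deposit = g_per_step if delay < since <= delay + duration else 0.0
        excess[t] = excess[t - 1] * rate + deposit
    return [BASELINE + e for e in excess]

def log_likelihood(params):
    """Gaussian log-likelihood up to an additive constant."""
    pred = simulate(len(bg), *params)
    return -sum((o - p) ** 2 for o, p in zip(bg, pred)) / (2 * SIGMA ** 2)

grid = itertools.product(
    range(3, 9),                          # delay, in 5-minute steps
    range(4, 9),                          # duration, in 5-minute steps
    [8.0, 10.0, 12.0, 14.0, 16.0],        # mg/dL per step
    [0.7, 0.75, 0.8, 0.85, 0.9, 0.95],    # regression rate
)
best = max(grid, key=log_likelihood)
print("best (delay, duration, g_5min, rate):", best)
```

This grid has only 900 points, so it runs almost instantly. A full day with several foods multiplies the grid size per food, so a real version would probably need a coarse-to-fine search or coordinate-wise updates.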

I think if I had a year of data instead of a month, this could get pretty interesting. I would also need to figure out how to sample quicker, which should be possible given the simplicity of the model. With the current implementation I am unable to sample even a month of data at once.

Finally, here is the complete set of plots I generated, to see the range of results.

[Figure: the complete set of food model plots]


Oxford Nanopore (ONT) sells an amazing, inexpensive sequencer called the MinION. It's an unusual device in that the sequencing "flowcells" use protein pores. Unlike a silicon chip, the pores are sensitive to their environment (especially heat) and can get damaged or degraded and stop working.

When you receive a flowcell from ONT, only a fraction of the possible 2048 pores are active, perhaps 800–1500. Because of how the flowcell works, you start a run using zero or one pores from each of the 512 "channels" (each of which contains 4 pores).

As the pores interact with the DNA and other stuff in your solution, they can get gummed up and stop working. It's not unusual to start a run with 400 pores and end a few hours later with half that many still active. It's also not unusual to put a flowcell in the fridge and find it has 20% fewer pores when you take it out. (Conversely, pores can come back to life after a bit of a rest.)

All told, this means it's quite difficult to tell how much sequence you can expect from a run. If you want to sequence a plasmid at 100X, and you are starting with 400 pores, will you get enough data in two hours? That's the kind of question we want an answer to.

import pymc3 as pm
import numpy as np
import scipy as sp
import theano.tensor as T

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from functools import partial
from IPython.display import display, HTML, Image

For testing purposes — and because I don't have nearly enough real data — I will simulate some data.

In this toy example, DNA of length 1000+/-50 bases is sequenced for 1000, 20000 or 40000 seconds at 250bp/s. After 100 reads, a pore gets blocked and can no longer sequence. The device I am simulating has only 3 pores (where a real MinION would have 512).

datatypes = [
    ('total_num_reads', int),           ('total_dna_read', int),
    ('num_minutes', float),             ('num_active_pores_start', int),
    ('num_active_pores_end', int),      ('mean_input_dna_len_estimate', float)]

MAX_PORES = 3
AXIS_RUNS, AXIS_PORES = 0, 1

data = np.array([
  ( MAX_PORES*9,   MAX_PORES*  9*1000*.95,    1000.0/60,  MAX_PORES,  MAX_PORES,  1000.0),
  ( MAX_PORES*10,  MAX_PORES* 10*1000*1.05,   1000.0/60,  MAX_PORES,  MAX_PORES,  1000.0),
  ( MAX_PORES*10,  MAX_PORES* 10*1000*.95,    1000.0/60,  MAX_PORES,  MAX_PORES,  1000.0),
  ( MAX_PORES*11,  MAX_PORES* 11*1000*1.05,   1000.0/60,  MAX_PORES,  MAX_PORES,  1000.0),
  ( MAX_PORES*100, MAX_PORES*100*1000*.95,   20000.0/60,  MAX_PORES,          0,  1000.0),
  ( MAX_PORES*100, MAX_PORES*100*1000*1.05,  20000.0/60,  MAX_PORES,          0,  1000.0),
  ( MAX_PORES*100, MAX_PORES*100*1000*.95,   40000.0/60,  MAX_PORES,          0,  1000.0),
  ( MAX_PORES*100, MAX_PORES*100*1000*1.05,  40000.0/60,  MAX_PORES,          0,  1000.0),
], dtype=datatypes)

MinION Simulator

To help figure out how long to leave our sequencer running, I made a MinION simulator in a jupyter notebook. I then turned the notebook into an app using dappled.io, an awesome new service that can turn jupyter notebooks into apps running on the cloud (including AWS lambda). Here is the dappled nanopore simulator app / notebook.

The model is simple: a pore reads DNA until it's completely read, then with low probability it gets blocked and can no longer sequence. I've been fiddling with the parameters to give a reasonable result for our runs, but we don't have too much data yet, so it's not terribly accurate. (The model itself may also be inaccurate, since how the pore gets blocked is not really discussed by ONT.)

Nevertheless, the simulator is a useful tool for ballparking how much sequence we should expect from a run.
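
For readers without access to the app, here is a minimal standard-library sketch of the kind of model described above. It is not the code behind the simulator. The function name `simulate_run` and all default values are assumptions for illustration: 1000-base DNA, 250 bases per second, an exponential capture time with a mean of 60 seconds, and a 1% chance of blocking after each read.

```python
import random

def simulate_run(num_pores, run_seconds, dna_len=1000, read_speed=250.0,
                 mean_capture_s=60.0, p_block=0.01, seed=0):
    """Simulate one run; returns (total reads, total bases, pores still active)."""
    rng = random.Random(seed)
    total_reads = 0
    active_at_end = 0
    for _ in range(num_pores):
        t = 0.0
        blocked = False
        while not blocked:
            t += rng.expovariate(1.0 / mean_capture_s)  # wait to capture DNA
            t += dna_len / read_speed                   # read the whole molecule
            if t > run_seconds:
                break                                   # run ended mid-read
            total_reads += 1
            blocked = rng.random() < p_block            # pore may block after a read
        if not blocked:
            active_at_end += 1
    return total_reads, total_reads * dna_len, active_at_end

reads, bases, active = simulate_run(num_pores=400, run_seconds=2 * 3600)
print(f"{reads} reads, {bases} bases, {active} of 400 pores still active")
```

Dividing the total bases by the plasmid length gives a rough coverage estimate for the question of whether two hours on 400 pores is enough.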

[Figure: nanopore simulation]

Hierarchical Bayesian Model

Here's where things get more complicated...

In theory, given some data from real MinION runs, we should be able to learn the parameters for a model that would enable us to predict how much data we would get from a new run. Like many problems, this is a good fit for Bayesian analysis, since we have data and we want to learn the most appropriate model given the data.

For each run, I need to know:

  • what length DNA I think I started with
  • how long the run was in minutes
  • how many reads I got
  • how much sequence I got
  • the number of active pores at the start and at the end of a run

I'll use PyMC3 for this problem. First, I need to specify the model.

Input DNA

The input DNA depends a lot on its preparation. I believe the distribution of input DNA lengths could be:

  • exponential if it is genomic DNA breaking randomly
  • normal with a small variance if it is from a plasmid
  • normal but sharply truncated if it is cut out of a gel

The length of input DNA is different to the length of reads produced by the sequencer, which is affected by breakage, capture biases, secondary structure, etc. The relationship between input DNA length and read length could be learned. We could get arbitrarily complex here: for example, given enough data, we could include a mixture model for different DNA types (genomic vs plasmid).

A simple distribution with small variance but fat tails seems reasonable here. I don't have much idea what the standard deviation should be, so I'll make it a fraction of the mean.

with pm.Model() as model:
    mean_input_dna_len = pm.StudentT('mean_input_dna_len', nu=3, mu=data['mean_input_dna_len_estimate'],
                                     sd=data['mean_input_dna_len_estimate']/5, shape=data.shape[0])

This is the first PyMC3 code, so here are some notes on what's going on:

  • the context must always specify the pymc3 model
  • Absent better ideas, I'll generally be using a T distribution instead of a normal, as MacKay recommends. This distribution should be truncated at zero, but I can't do this for multidimensional data because of what I think is a bug in pymc3.
  • I am not modeling the input_dna_len here, just the mean of that distribution. As I mention above, the distribution of DNA lengths could be modeled several ways, but I currently only need the mean in my model.

the shape parameter

In this experiment, we will encounter three different kinds of distribution shape:

  • shape=1: this is a scalar value. For example, we might think the speed of the nanopore machine is the same for all runs.
  • shape=data.shape[0]: data.shape[0] is the number of runs. For example, the length of DNA in solution varies from run to run.
  • shape=(data.shape[0],MAX_PORES): we estimate a value for every pore in every run. For example, how many reads until a pore gets blocked? This varies by run, but needs to be modeled per pore too.

R9 Read Speed

I know that the R9 pore is supposed to read at 250 bases per second. I believe the mean could be a little bit more or less than that, and I believe that all flowcells and devices should be about the same speed, therefore all runs will sample from the same distribution of mean_read_speed.

Because this is a scalar, pymc3 lets me set a lower bound of 0 bases per second. (Interestingly, unlike DNA length, the speed could technically be negative, since the voltage across the pore can be positive or negative, as exploited by Read Until...)

Truncated0T1D = pm.Bound(pm.StudentT, lower=0)
with pm.Model() as model:
    mean_read_speed = Truncated0T1D('mean_read_speed', nu=3, mu=250, sd=10, shape=data.shape)

Capturing DNA

DNA is flopping around in solution, randomly entering pores now and then. How long will a pore have to wait for DNA? I have a very vague idea that it should take around a minute but this is basically an unknown that the sampler should figure out.

Note again that I am using the mean time only, not the distribution of times. The actual times to capture DNA would likely follow an exponential distribution (the waiting time for a Poisson process).
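
To illustrate the exponential waiting time, here is a short standard-library sketch that draws capture times with a mean of 60 seconds (my vague guess from above, so treat it as an assumption) and checks how often a pore would wait more than two minutes.

```python
import math
import random

MEAN_CAPTURE_S = 60.0   # assumed mean waiting time, in seconds
rng = random.Random(0)

waits = [rng.expovariate(1.0 / MEAN_CAPTURE_S) for _ in range(100_000)]
sample_mean = sum(waits) / len(waits)
frac_over_2min = sum(w > 120 for w in waits) / len(waits)
expected_frac = math.exp(-120 / MEAN_CAPTURE_S)  # P(wait > 120 s) for an exponential

print(f"sample mean: {sample_mean:.1f} s")
print(f"waits over 2 minutes: {frac_over_2min:.3f} (theory {expected_frac:.3f})")
```

The distribution has a long right tail, which is one reason a single mean may hide a lot of per-pore variation.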

Hierarchical model

This is the first time I am using a hierarchical/multilevel model. For more on what this means, see this example from pymc3 or Andrew Gelman's books.

There are three options for modeling mean_time_to_capture_dna: (a) it's the same for each run (e.g., everybody is using the same recommended DNA concentration); (b) it's independent for each run; (c) each run has a different mean, but the means are drawn from the same distribution (and probably should not be too different).

Image("hierarchical_pymc3.png")
Diagram taken from PyMC3 dev twiecki's blog: http://twiecki.github.io/blog/2014/03/17/bayesian-glms-3/

with pm.Model() as model:
    prior_mean_time_to_capture_dna = Truncated0T1D('prior_mean_time_to_capture_dna', nu=3, mu=60, sd=30)
    mean_time_to_capture_dna = pm.StudentT('mean_time_to_capture_dna', nu=3, mu=prior_mean_time_to_capture_dna,
                                           sd=prior_mean_time_to_capture_dna/10, shape=data.shape)
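
To make the three options concrete, here is a rough generative sketch in plain Python. It only draws from the priors and does no inference. I use Normal draws instead of the StudentT above to keep it short, and draw_positive is a hypothetical helper standing in for the truncation at zero.

```python
import random
import statistics

random.seed(1)

NUM_RUNS = 8        # assumed: the number of runs in the toy data
PRIOR_MEAN = 60.0   # seconds
PRIOR_SD = 30.0


def draw_positive(mu, sd):
    """Rejection-sample a Normal truncated at zero (hypothetical helper)."""
    while True:
        value = random.gauss(mu, sd)
        if value > 0:
            return value


# (a) pooled: every run shares exactly the same mean
pooled = [PRIOR_MEAN] * NUM_RUNS

# (b) unpooled: each run draws its own mean straight from the wide prior
unpooled = [random.gauss(PRIOR_MEAN, PRIOR_SD) for _ in range(NUM_RUNS)]

# (c) hierarchical: draw one shared mean, then scatter the runs tightly around it
shared_mean = draw_positive(PRIOR_MEAN, PRIOR_SD)
hierarchical = [random.gauss(shared_mean, shared_mean / 10) for _ in range(NUM_RUNS)]

for name, values in [("pooled", pooled), ("unpooled", unpooled), ("hierarchical", hierarchical)]:
    print(f"{name:>12}: spread between runs = {statistics.pstdev(values):.1f} s")
```

Option (c) sits between the other two: the runs are allowed to differ, but only by about a tenth of the shared mean.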

Reading DNA

I can use pymc3's Deterministic type to calculate how long a pore spends reading a chunk of DNA, on average. That's just a division, but note it uses theano's true_div function instead of a regular division. This is because neither value is a number; they are random variables. Theano will calculate this as it's being sampled. (A simple "a/b" also works, but I like to keep in mind what's being done by theano if possible.)

with pm.Model() as model:
    mean_time_to_read_dna = pm.Deterministic('mean_time_to_read_dna', T.true_div(mean_input_dna_len, mean_read_speed))

How many reads can each pore then do in this run? I have to be a bit careful to specify that I mean the number of reads possible per pore.

with pm.Model() as model:
    num_reads_possible_in_time_per_pore = pm.Deterministic('num_reads_possible_in_time_per_pore',
        T.true_div(num_seconds_run, T.add(mean_time_to_capture_dna, mean_time_to_read_dna)))
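
Here is the same arithmetic with plain numbers, using the toy values quoted in the Results section (1000 bases, 250 bases per second, 96 seconds to capture). The run durations are an assumption on my part: I picked them so that the reads-possible values come out at 10, 200 and 400.

```python
mean_input_dna_len = 1000.0        # bases
mean_read_speed = 250.0            # bases per second
mean_time_to_capture_dna = 96.0    # seconds

# Time spent reading one molecule, then the full capture + read cycle
mean_time_to_read_dna = mean_input_dna_len / mean_read_speed            # 4 s
seconds_per_read = mean_time_to_capture_dna + mean_time_to_read_dna     # 100 s

# Assumed run durations in seconds (hypothetical values for illustration)
num_seconds_run = [1_000, 20_000, 40_000]

num_reads_possible_in_time_per_pore = [t / seconds_per_read for t in num_seconds_run]
print(num_reads_possible_in_time_per_pore)  # [10.0, 200.0, 400.0]
```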

Blocking Pores

In my model, after each read ends, a pore can get blocked, and once it's blocked it does not become unblocked. I believe about one in a hundred times, a pore will be blocked after it finishes sequencing DNA. If it were more than that, sequencing would be difficult, but it could be much less.

We can think of the Beta distribution as modeling coin-flips where the probability of heads (pore getting blocked) is 1/100.

with pm.Model() as model:
    prior_p_pore_blocked_a, prior_p_pore_blocked_b = 1, 99
    p_pore_blocked = pm.Beta('p_pore_blocked', alpha=prior_p_pore_blocked_a, beta=prior_p_pore_blocked_b)
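
To see what this prior actually says, here is a small sketch that works out its mean and a couple of quantiles in plain Python. It relies on the fact that for Beta(1, b) the CDF has the closed form 1 - (1 - x)^b; beta_1_b_quantile is a hypothetical helper name.

```python
a, b = 1, 99

# Mean of a Beta(a, b) distribution
prior_mean = a / (a + b)


def beta_1_b_quantile(q, b):
    """Quantile of Beta(1, b), from inverting CDF(x) = 1 - (1 - x)**b."""
    return 1.0 - (1.0 - q) ** (1.0 / b)


prior_median = beta_1_b_quantile(0.5, b)
prior_q95 = beta_1_b_quantile(0.95, b)

print(f"mean={prior_mean:.4f} median={prior_median:.4f} 95th percentile={prior_q95:.4f}")
```

So the prior is centered near 1/100 but puts most of its mass below that, with only about 5% of it above roughly 3 in 100.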

Then, per pore, the number of reads before blockage is distributed as a geometric distribution (the distribution of the number of tails you will flip before you flip a head). I have to approximate this with an exponential distribution — the continuous version of a geometric — because ADVI (see below) requires continuous distributions. I don't think it makes an appreciable difference.

Here the shape parameter is the number of runs x number of pores since here I need a value for every pore.

with pm.Model() as model:
    num_reads_before_blockage = pm.Exponential('num_reads_before_blockage', lam=p_pore_blocked,
        shape=(data.shape[0], MAX_PORES))
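
As a sanity check on that approximation, here is a small simulation comparing the two distributions at p = 1/100. This is a sketch in plain Python, not part of the model, and geometric_tails_before_head is a hypothetical helper. The geometric (tails before the first head) has mean (1 - p)/p = 99, while the exponential with rate p has mean 1/p = 100.

```python
import random

random.seed(2)

p = 0.01   # probability that a pore gets blocked after a read
N = 20_000


def geometric_tails_before_head(p):
    """Count the tails flipped before the first head (hypothetical helper)."""
    n = 0
    while random.random() >= p:
        n += 1
    return n


geo = [geometric_tails_before_head(p) for _ in range(N)]
expo = [random.expovariate(p) for _ in range(N)]

geo_mean = sum(geo) / N
expo_mean = sum(expo) / N
print(f"geometric mean: {geo_mean:.1f}, exponential mean: {expo_mean:.1f}")
```

At this value of p the two means differ by about 1%, which is small next to the spread of either distribution.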

Constraints

Here things get a little more complicated. It is not possible to have two random variables x and y, set x + y = data, and sample values of x and y.

testdata = np.random.normal(loc=10,size=100) + np.random.normal(loc=1,size=100)
with pm.Model() as testmodel:
    x = pm.Normal("x", mu=0, sd=10)
    y = pm.Normal("y", mu=0, sd=1)
    # This does not work
    #z = pm.Deterministic(x + y, observed=data)
    # This does work
    z = pm.Normal('x+y', mu=x+y, sd=1, observed=testdata)
    trace = pm.sample(1000, pm.Metropolis())

I was surprised by this at first but it makes sense. This process is more like a regression where you minimize error than a perfectly fixed constraint. The smaller the standard deviation, the more you penalize deviations from the constraint. However, you need some slack so that it's always possible to estimate a logp for the data given the model. If the standard deviation goes too low, you end up with numerical problems (e.g., nans). Unfortunately, I do not know of a reasonable way to set this value.
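
To put numbers on that penalty, here is a sketch of the Normal log-density in plain Python. The observed value and the 0.5 miss are made-up numbers for illustration. The cost of missing the constraint is the miss squared divided by 2*sd^2, so it grows very quickly as the standard deviation shrinks.

```python
import math


def normal_logp(value, mu, sd):
    """Log-density of Normal(mu, sd) evaluated at value."""
    return -0.5 * math.log(2 * math.pi * sd ** 2) - (value - mu) ** 2 / (2 * sd ** 2)


observed = 11.0   # hypothetical observed data point
x_plus_y = 10.5   # hypothetical current value of x + y, missing by 0.5

penalty = {}
for sd in (10.0, 1.0, 0.1, 0.01):
    exact = normal_logp(observed, observed, sd)   # logp if the constraint held exactly
    missed = normal_logp(observed, x_plus_y, sd)  # logp with the 0.5 miss
    penalty[sd] = exact - missed
    print(f"sd={sd:>5}: logp={missed:10.3f}  penalty for the miss={penalty[sd]:.5f}")
```

With sd=1 the miss costs about 0.125 in logp; with sd=0.01 it costs 1250, which is the regime where the numerical problems start.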

I encode constraints in my model in a similar way: First I define a Laplace distribution in theano to act as my likelihood. Why Laplace instead of Normal? It returned nans less often...

Using your own DensityDist allows you to use any function as a likelihood. I use DensityDist just to have more control over the likelihood, and to differentiate from a true Normal/Laplacian distribution. This can be finicky, so I've had to spend a bunch of time tweaking this.

def T_Laplace(val, obs, b=1):
    return T.log(T.mul(1/(2*b), T.exp(T.neg(T.true_div(T.abs_(T.sub(val, obs)), b)))))
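
One thing worth noting about this form: taking the log of an exp can underflow. Below is a sketch in plain Python (not theano) comparing the literal log-of-density form with the algebraically equivalent closed form, -log(2b) - |val - obs|/b. Both function names are hypothetical. I have not checked whether theano rewrites the expression itself, so treat this as one possible source of the -inf values, not a diagnosis.

```python
import math


def laplace_logp_naive(val, obs, b=1.0):
    """log of the Laplace density, computed the same way as T_Laplace above."""
    density = (1.0 / (2.0 * b)) * math.exp(-abs(val - obs) / b)
    # exp() underflows to 0.0 for large deviations, and log(0) is -inf
    return math.log(density) if density > 0.0 else float("-inf")


def laplace_logp_stable(val, obs, b=1.0):
    """Same quantity, simplified so that exp() is never evaluated."""
    return -math.log(2.0 * b) - abs(val - obs) / b


small_naive = laplace_logp_naive(12.0, 10.0)
small_stable = laplace_logp_stable(12.0, 10.0)
large_naive = laplace_logp_naive(1010.0, 10.0)
large_stable = laplace_logp_stable(1010.0, 10.0)

print(small_naive, small_stable)   # the two agree for a small deviation
print(large_naive, large_stable)   # the naive form gives -inf, the stable form stays finite
```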

Constraint #1

The total DNA I have read must be the product of mean_input_dna_len and total_num_reads. mean_input_dna_len can be different to my estimate.

This is a bit redundant since I know total_num_reads and total_dna_read exactly. In a slightly more complex model we would have different mean_read_length and mean_input_dna_len. Here it amounts to just calculating mean_input_dna_len as total_dna_read / total_num_reads.

def apply_mul_constraint(f1, f2, observed):
    b = 1 # 1 fails with NUTS (logp=-inf), 10/100 fails with NUTS too (not positive definite)
    return T_Laplace(T.mul(f1,f2), observed, b)

with pm.Model() as model:
    total_dna_read_constraint = partial(apply_mul_constraint, mean_input_dna_len, data['total_num_reads'])
    constrain_total_dna_read = pm.DensityDist('constrain_total_dna_read', total_dna_read_constraint, observed=data['total_dna_read'])

Constraint #2

The number of reads per pore is whichever is lower:

  • the number of reads the pore could manage in the length of the run
  • the number of reads it manages before getting blocked

To calculate this, I compare num_reads_before_blockage and num_reads_possible_in_time_per_pore.

First, I use T.tile to replicate num_reads_possible_in_time_per_pore for all pores. That turns it from an array of length #runs to a matrix of shape #runs x #pores (this should use broadcasting but I wasn't sure it was working properly...)

Then I compare these two arrays elementwise (T.lt). If the number of reads before blockage is the smaller of the two (T.switch(T.lt(f1,f2),0,1)) then that pore is blocked, otherwise it is active. I sum these 0/1s over all pores (axis=AXIS_PORES) to get a count of the number of active pores for each run.

The value of this count is constrained to be equal to data['num_active_pores_end'].

def apply_count_constraint(num_reads_before_blockage, num_reads_possible_in_time_broadcast, observed):
    b = 1
    num_active_pores = T.sum(T.switch(T.lt(num_reads_before_blockage, num_reads_possible_in_time_broadcast),0,1),axis=AXIS_PORES)
    return T_Laplace(num_active_pores, observed, b)

with pm.Model() as model:
    num_reads_possible_in_time_broadcast = T.tile(num_reads_possible_in_time_per_pore, (MAX_PORES,1)).T
    num_active_pores_end_constraint = partial(apply_count_constraint, num_reads_before_blockage, num_reads_possible_in_time_broadcast)
    constrain_num_active_pores_end = pm.DensityDist('constrain_num_active_pores_end', num_active_pores_end_constraint, observed=data['num_active_pores_end'])
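
Here is the same count in numpy, on a tiny made-up example with 2 runs and 3 pores. It is a sketch only: the numbers are hypothetical, and I am assuming AXIS_PORES is 1 (runs on axis 0, pores on axis 1). It also checks that explicit tiling and numpy broadcasting give the same comparison.

```python
import numpy as np

AXIS_PORES = 1  # assumed: axis 0 is runs, axis 1 is pores

# Hypothetical values: run 0 is short, run 1 is long enough for pores to block
num_reads_possible = np.array([10.0, 200.0])                    # shape (2,)
num_reads_before_blockage = np.array([[500.0, 650.0, 120.0],
                                      [100.0, 100.0, 100.0]])   # shape (2, 3)

# Explicit tiling, as in the theano code above: (2,) -> (3, 2) -> transpose -> (2, 3)
tiled = np.tile(num_reads_possible, (3, 1)).T
# Broadcasting: a (2, 1) column stretches across the pore axis automatically
column = num_reads_possible[:, np.newaxis]

blocked_tiled = num_reads_before_blockage < tiled
blocked_broadcast = num_reads_before_blockage < column

# A blocked pore counts as 0 and an active pore as 1, summed over pores
num_active_pores = np.where(blocked_broadcast, 0, 1).sum(axis=AXIS_PORES)
print(num_active_pores)  # [3 0]
```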

Constraint #3

Using the same matrix num_reads_possible_in_time_broadcast this time I sum the total number of reads from each pore in a run. I simply sum the minimum value from each pore: either the number before blockage occurs or the total number possible.

def apply_minsum_constraint(num_reads_before_blockage, num_reads_possible_in_time_broadcast, observed):
    b = 1 # b=1 fails with ADVI and >100 pores (nan)
    min_reads_per_run = T.sum(T.minimum(num_reads_before_blockage, num_reads_possible_in_time_broadcast),axis=AXIS_PORES)
    return T_Laplace(min_reads_per_run, observed, b)

with pm.Model() as model:
    total_num_reads_constraint = partial(apply_minsum_constraint, num_reads_before_blockage, num_reads_possible_in_time_broadcast)
    constrain_total_num_reads = pm.DensityDist('constrain_total_num_reads', total_num_reads_constraint, observed=data['total_num_reads'])
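
The min-sum can be checked in numpy on a tiny made-up example with 2 runs and 3 pores. Again a sketch: the numbers are hypothetical and AXIS_PORES = 1 is my assumption about the axis layout.

```python
import numpy as np

AXIS_PORES = 1  # assumed: axis 0 is runs, axis 1 is pores

num_reads_possible = np.array([10.0, 200.0])                    # one value per run
num_reads_before_blockage = np.array([[500.0, 650.0, 120.0],
                                      [100.0, 100.0, 100.0]])   # runs x pores

# Each pore delivers whichever is smaller: what time allows, or what blockage allows
reads_per_pore = np.minimum(num_reads_before_blockage,
                            num_reads_possible[:, np.newaxis])

total_num_reads = reads_per_pore.sum(axis=AXIS_PORES)
print(total_num_reads.tolist())  # [30.0, 300.0]
```

In run 0 no pore blocks, so all three pores deliver the 10 reads that time allows. In run 1 every pore blocks at 100 reads, well short of the 200 that were possible.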

Sampling

There are three principal ways to sample in PyMC3:

  • Metropolis-Hastings: the simplest sampler, generally works fine on simpler problems
  • NUTS: more efficient than Metropolis, but I've found it to be slow and tricky to get to work
  • ADVI: this is "variational inference", the new, fast way to estimate a posterior. This seems to work great, though as I mention above, it needs continuous distributions only.

I used ADVI most of the time in this project, since NUTS was too slow and had more numerical issues than ADVI, and the ADVI results seemed more sensible than Metropolis.

with pm.Model() as model:
    v_params = pm.variational.advi(n=500000, random_seed=1)
    trace = pm.variational.sample_vp(v_params, draws=5000)
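
For intuition about the simplest of the three samplers, here is a minimal random-walk Metropolis sketch in plain Python. This is not PyMC3's implementation. The target is a standard Normal and the step size is an arbitrary choice, both picked only to keep the example short.

```python
import math
import random

random.seed(3)


def log_target(x):
    """Log-density of a standard Normal, up to an additive constant."""
    return -0.5 * x * x


def metropolis(log_p, start, n_samples, step=1.0):
    """Random-walk Metropolis with a symmetric Normal proposal."""
    samples = []
    current, current_logp = start, log_p(start)
    for _ in range(n_samples):
        proposal = current + random.gauss(0.0, step)
        proposal_logp = log_p(proposal)
        # Accept with probability min(1, p(proposal) / p(current));
        # 1 - random() lies in (0, 1], so the log is always defined.
        if math.log(1.0 - random.random()) < proposal_logp - current_logp:
            current, current_logp = proposal, proposal_logp
        samples.append(current)
    return samples


samples = metropolis(log_target, start=0.0, n_samples=20_000)
sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
print(f"mean={sample_mean:.2f} variance={sample_var:.2f}")  # should be near 0 and 1
```

The sampler only ever needs the logp of the current and proposed points, which is why it is easy to apply but slow to explore when there are many correlated parameters.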

Results

In my toy model: there are 8 runs in total; each device has 3 pores; it takes about 4 seconds to sequence a DNA molecule (~1000 bases / 250 bases per second) and 96 seconds to capture a DNA molecule (so 100 seconds in total to read one DNA molecule).

In my 8 runs, I expect that [10, 10, 10, 10, 200, 200, 400, 400] reads are possible. However, I expect that runs 4, 5, 6, 7 will have a blockage at 100 reads.

Finally, I expect there to be 3 pores remaining in the first four runs, and 0 pores remaining in the last four.

ADVI Results

mean_input_dna_len__0                    950.0 (+/- 0.1) # All the input_dna_lens are ok.
mean_input_dna_len__1                   1050.0 (+/- 0.1) # This is a simple calculation since I know
mean_input_dna_len__2                    950.0 (+/- 0.1) # how many reads there are and how long the reads are
mean_input_dna_len__3                   1050.0 (+/- 0.1)
mean_input_dna_len__4                    950.0 (+/- 0.1)
mean_input_dna_len__5                   1049.9 (+/- 0.1)
mean_input_dna_len__6                    950.0 (+/- 0.1)
mean_input_dna_len__7                   1050.0 (+/- 0.1)
num_reads_possible_in_time_per_pore__0    10.1 (+/- 1) # The number of reads possible is also a simple calculation
num_reads_possible_in_time_per_pore__1    10.3 (+/- 1) # It depends on the speed of the sequencer and the length of DNA
num_reads_possible_in_time_per_pore__2    10.2 (+/- 1)
num_reads_possible_in_time_per_pore__3    10.5 (+/- 1)
num_reads_possible_in_time_per_pore__4   210.9 (+/- 36)
num_reads_possible_in_time_per_pore__5   207.4 (+/- 35)
num_reads_possible_in_time_per_pore__6   419.8 (+/- 67)
num_reads_possible_in_time_per_pore__7   413.2 (+/- 66)
num_reads_before_blockage_per_run__0     501.3 (+/- 557) # The total number of reads before blockage per run
num_reads_before_blockage_per_run__1     509.8 (+/- 543) # is highly variable when there is no blockage (runs 0-3).
num_reads_before_blockage_per_run__2     501.9 (+/- 512)
num_reads_before_blockage_per_run__3     502.4 (+/- 591)
num_reads_before_blockage_per_run__4     297.2 (+/- 39)  # When there is blockage (runs 4-7), then it's much less variable
num_reads_before_blockage_per_run__5     298.7 (+/- 38)
num_reads_before_blockage_per_run__6     299.6 (+/- 38)
num_reads_before_blockage_per_run__7     301.2 (+/- 38)
num_active_pores_per_run__0                2.8 (+/- 0.4) # The number of active pores per run is estimated correctly
num_active_pores_per_run__1                2.8 (+/- 0.4) # as we expect for a value we've inputted
num_active_pores_per_run__2                2.8 (+/- 0.4)
num_active_pores_per_run__3                2.8 (+/- 0.3)
num_active_pores_per_run__4                0.0 (+/- 0.1)
num_active_pores_per_run__5                0.0 (+/- 0.1)
num_active_pores_per_run__6                0.0 (+/- 0.0)
num_active_pores_per_run__7                0.0 (+/- 0.0)

ADVI Plots

We can plot the values above too. No real surprises here, though num_active_pores_per_run doesn't show the full distribution for some reason.

with pm.Model() as model:
    pm.traceplot(trace[-1000:],
                 figsize=(12,len(trace.varnames)*1.5),
                 lines={k: v['mean'] for k, v in pm.df_summary(trace[-1000:]).iterrows()})

Conclusions

It's a pretty complex model. Is it actually useful? I think it's almost useful, but in its current form it suffers from a lack of robustness.

The results I get are sensitive to changes in several independent areas:

  • the sampler/inference method: for example, Metropolis returns different answers to ADVI
  • the constraint parameters: changing the "b" parameter can lead to numerical problems
  • the priors I selected: changes in my priors that I do not perceive as important can lead to different results (more data would help here)

A big problem with my current model is that scaling it up to 512 pores seems to be difficult numerically. Metropolis just fails, I think because it can't sample efficiently; NUTS fails for reasons I don't understand (it throws an error about a positive definite matrix); ADVI works best, but starts to get nans as the number of pores grows, unless I loosen the constraints. It's possible that I need to begin with loose constraints and tighten them over time.

Finally, the model currently expects the same number of pores to be available for every run. I haven't addressed that yet, though I think it should be pretty straightforward. There may be a nice theano trick I am overlooking.

More Conclusions

Without ADVI, I think I would have failed to get an answer here. I'm pretty ignorant of how it works, but it seems like a significant advance and I'll definitely use it again. In fact, it would be interesting to try applying Edward, a variational inference toolkit with PyMC3 support, to this problem (this would mean swapping theano for tensorflow).

The process of modeling your data with a PyMC3/Stan/Edward model forces you to think a lot about what is really going on with your data and model (and even after spending quite a bit of time on it, my model still needs quite a bit of work to be more than a toy...) When your model has computational problems, as I had several times, it often means the model wasn't described correctly (Gelman's folk theorem).

Although it is still a difficult, technical process, I'm excited about this method of doing science. It seems like the right way to tackle a lot of problems. Maybe with advances like ADVI and theano/tensorflow; groups like the Blei and Gelman labs developing modeling tools like PyMC3, Stan and Edward; and experts like Kruschke and the Stan team creating robust models for us to copy, it will become more common.

This github repository has all of the exercises from Kruschke's Doing Bayesian Data Analysis book translated from R to Python/PyMC3.

I transferred the code into a giant IPython notebook.

Brian Naughton | Sun 16 November 2014 | data | data pymc3

How to make a Gaussian Mixture Model in PyMC3

Boolean Biotech © Brian Naughton