
Conjugate Models


Bayesian conjugate models in Python

Installation

pip install conjugate-models

Features

Supported Models

Many likelihoods are supported, including:

  • Bernoulli / Binomial
  • Categorical / Multinomial
  • Poisson
  • Normal (including linear regression)
  • and many more

Basic Usage

Working with Pre-processed Data

  1. Define a prior distribution from the distributions module
  2. Pass the data and prior into a model from the models module
  3. Analyze with the posterior and posterior predictive distributions

from conjugate.distributions import Beta, BetaBinomial
from conjugate.models import binomial_beta, binomial_beta_predictive

# Observed Data (sufficient statistics)
x = 4  # successes
N = 10 # trials

# Analytics
prior = Beta(1, 1)
prior_predictive: BetaBinomial = binomial_beta_predictive(n=N, distribution=prior)

posterior: Beta = binomial_beta(n=N, x=x, prior=prior)
posterior_predictive: BetaBinomial = binomial_beta_predictive(
    n=N, distribution=posterior
)
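
The returned objects are plain parameter containers, so the updated values can be read off directly. As a minimal sketch (assuming the Beta object exposes its alpha and beta parameters as attributes):

# Posterior is Beta(1 + x, 1 + N - x) = Beta(5, 7)
alpha, beta = posterior.alpha, posterior.beta
posterior_mean = alpha / (alpha + beta)  # ~0.42, pulled toward the prior mean of 0.5
print(posterior_mean)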

Working with Raw Observational Data

For raw data, use helper functions from the helpers module to extract sufficient statistics:

from conjugate.distributions import Beta
from conjugate.models import binomial_beta
from conjugate.helpers import bernoulli_beta_inputs

# Raw observational data - individual trial outcomes
raw_data = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]  # success/failure per trial

# Extract sufficient statistics automatically
inputs = bernoulli_beta_inputs(raw_data)
print(inputs)  # {'x': 6, 'n': 10} - 6 successes in 10 trials

# Use with conjugate model
prior = Beta(1, 1)
posterior = binomial_beta(prior=prior, **inputs)
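
The helper simply computes the counts shown in its output, so the manual version below gives the same posterior:

# Manual equivalent: the Beta-Binomial update only needs the sufficient statistics
posterior_manual = binomial_beta(n=len(raw_data), x=sum(raw_data), prior=prior)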

Common Helper Function Patterns

from conjugate.helpers import (
    poisson_gamma_inputs,      # For count data
    normal_known_variance_inputs,  # For continuous measurements
    exponential_gamma_inputs,  # For time-between-events data
    multinomial_dirichlet_inputs,  # For categorical data
)

# Count data (e.g., website visits per day)
count_data = [5, 3, 8, 2, 6, 4, 7, 1, 9, 3]
inputs = poisson_gamma_inputs(count_data)
# Returns: {'x': sum(count_data), 'n': len(count_data)}

# Continuous measurements with known variance
measurements = [2.3, 1.9, 2.7, 2.1, 2.5]
variance = 0.5
inputs = normal_known_variance_inputs(measurements, variance=variance)
# Returns: {'x_mean': mean(measurements), 'n': len(measurements), 'variance': variance}

# Time between events (e.g., customer arrivals)
wait_times = [3.2, 1.8, 4.1, 2.7, 3.9]
inputs = exponential_gamma_inputs(wait_times)
# Returns: {'x': sum(wait_times), 'n': len(wait_times)}

# Categorical outcomes (e.g., survey responses A, B, C)
responses = ['A', 'B', 'A', 'C', 'B', 'A', 'B']
inputs = multinomial_dirichlet_inputs(responses)
# Returns: {'x': [3, 3, 1]} - counts for each category

All 50+ helper functions follow the same pattern: raw observations in → sufficient statistics out → ready for conjugate models.
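
For example, the count-data helper above can be piped straight into the corresponding model. This is a sketch under the assumption that a poisson_gamma model and Gamma distribution are available and accept the x/n keywords the helper returns:

from conjugate.distributions import Gamma
from conjugate.models import poisson_gamma
from conjugate.helpers import poisson_gamma_inputs

count_data = [5, 3, 8, 2, 6, 4, 7, 1, 9, 3]

# Helper output unpacks directly into the model's keyword arguments
prior = Gamma(1, 1)
poisson_posterior = poisson_gamma(prior=prior, **poisson_gamma_inputs(count_data))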

From here, do any analysis you'd like!

# Figure
import matplotlib.pyplot as plt

fig, axes = plt.subplots(ncols=2)

ax = axes[0]
ax = posterior.plot_pdf(ax=ax, label="posterior")
prior.plot_pdf(ax=ax, label="prior")
ax.axvline(x=x / N, color="black", ymax=0.05, label="MLE")
ax.set_title("Success Rate")
ax.legend()

ax = axes[1]
posterior_predictive.plot_pmf(ax=ax, label="posterior predictive")
prior_predictive.plot_pmf(ax=ax, label="prior predictive")
ax.axvline(x=x, color="black", ymax=0.05, label="Sample")
ax.set_title("Number of Successes")
ax.legend()
plt.show()
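
For instance, a credible interval can be pulled straight from the posterior. This sketch assumes each distribution wraps a scipy distribution exposed through a dist attribute:

# 94% credible interval for the success rate
lower, upper = posterior.dist.ppf([0.03, 0.97])
print(f"94% credible interval: ({lower:.2f}, {upper:.2f})")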

More examples are available in the documentation.

Contributing

If you are interested in contributing, check out the contributing guidelines.