pymgcv: Generalized Additive Models in Python
pymgcv provides a Python interface and visualisation tools for R's powerful mgcv library for fitting Generalized Additive Models (GAMs).
Installation
Warning
The package is very new. Expect breaking changes without warning until the package stabilises.
Installing the Python package installs only pymgcv and its Python dependencies, so an R installation with mgcv is also required.
Conda and pixi provide two convenient options for handling both Python and R dependencies:
Using conda:
conda create --name my_env python r-base r-mgcv
conda activate my_env
uv add pymgcv
source .venv/bin/activate # Or shell/OS-specific activation
Using pixi:
- Install pixi
With either method, the example below should now run, e.g. in a terminal after starting python, or in an IDE after selecting the pixi/conda environment.
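Because pymgcv needs R and mgcv at runtime, it can help to confirm that both are visible from the active environment before running anything. The sketch below is not part of pymgcv: it assumes only that the `Rscript` executable shipped with R is on the PATH (conda or pixi environments that include r-base normally provide it) and asks R whether it can load mgcv.

```python
import shutil
import subprocess


def r_has_mgcv() -> bool:
    """Return True if Rscript is on PATH and R can load the mgcv package."""
    rscript = shutil.which("Rscript")
    if rscript is None:
        return False  # no R installation visible from this environment
    # requireNamespace() returns FALSE (instead of erroring) when mgcv is missing;
    # q(status = 1) turns that into a non-zero exit code we can check from Python.
    r_code = "if (!requireNamespace('mgcv', quietly = TRUE)) q(save = 'no', status = 1)"
    try:
        result = subprocess.run(
            [rscript, "-e", r_code],
            capture_output=True,
            timeout=60,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


if __name__ == "__main__":
    print("R with mgcv found" if r_has_mgcv() else "R with mgcv NOT found")
```

If this reports that mgcv is missing, the Python side may be installed correctly while the R side is not, which is the usual cause of import or fitting errors.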
What are GAMs?
Generalized Additive Models (GAMs) are a flexible class of statistical models that extend linear models by allowing non-linear relationships between predictors and the response variable. For example, the model may have the form
$$ g(\mathbb{E}[Y]) = \beta_0 + \sum_{j=1}^p f_j(x_j), $$
where \(\beta_0\) is an intercept and:
- \(g\) is the link function, which maps the expected value of the response onto a scale where modelling it as a sum of smooth functions is reasonable (e.g. a log link for count data).
- \(f_j\) are smooth functions (e.g. splines) that capture the possibly non-linear relationship between each predictor \(x_j\) and the response.
- Bivariate smooths such as \(f(x_1, x_2)\) (and higher-dimensional ones) are also possible when interactions between predictors are important.
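To make the formula concrete, the sketch below evaluates it by hand for a model with a log link, \(g = \log\), and two smooth terms. The functions `f1`, `f2` and the intercept `beta0` are made-up stand-ins for what a fitted GAM would estimate; no pymgcv or mgcv code is involved.

```python
import numpy as np

beta0 = 1.0  # hypothetical intercept


def f1(x):
    """Hypothetical smooth effect of x1 (a fitted GAM would estimate this)."""
    return np.sin(np.pi * x)


def f2(x):
    """Hypothetical smooth effect of x2."""
    return 0.5 * x**2


def linear_predictor(x1, x2):
    """eta = beta0 + f1(x1) + f2(x2), the right-hand side of g(E[Y]) = ..."""
    return beta0 + f1(x1) + f2(x2)


def expected_response(x1, x2):
    """With the log link g = log, E[Y] = g^{-1}(eta) = exp(eta)."""
    return np.exp(linear_predictor(x1, x2))


x1 = np.array([0.0, 0.5, 0.5])
x2 = np.array([0.0, 0.0, 1.0])
eta = linear_predictor(x1, x2)  # additive on the link scale
mu = expected_response(x1, x2)  # multiplicative on the response scale
print(eta)
print(mu)
```

Because the terms add on the link scale, moving `x2` from 0 to 1 shifts \(\eta\) by `f2(1) - f2(0) = 0.5` whatever the value of `x1`; on the response scale the same change multiplies \(\mathbb{E}[Y]\) by \(e^{0.5}\). This separability is what lets each term be plotted and interpreted on its own.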
Why GAMs?
- Flexibility: Capture non-linear relationships without specifying their functional form in advance
- Interpretability: The additive nature allows each term to be visualized and understood separately
- Statistical rigor: Built-in smoothing parameter estimation and uncertainty quantification
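The "built-in smoothing parameter estimation" point is central to GAM fitting: each smooth is a spline with many basis functions, and a penalty controlled by a smoothing parameter \(\lambda\) stops it from overfitting. The numpy sketch below illustrates that idea for a single smooth, using a truncated-power cubic basis with a ridge penalty and choosing \(\lambda\) by generalized cross-validation (GCV) over a grid. It is a hand-rolled illustration under those assumptions, not how pymgcv or mgcv work internally: mgcv uses other bases and penalties (e.g. thin plate regression splines by default) and optimises criteria such as GCV or REML numerically rather than by grid search.

```python
import numpy as np


def truncated_cubic_basis(x, knots):
    """Columns: 1, x, x^2, x^3, then (x - k)_+^3 for each knot k."""
    poly = np.vander(x, 4, increasing=True)                      # n x 4
    trunc = np.clip(x[:, None] - knots[None, :], 0, None) ** 3   # n x K
    return np.hstack([poly, trunc])


def fit_penalized_spline(x, y, lam, knots):
    """Penalized least squares; returns coefficients, fitted values and effective df."""
    X = truncated_cubic_basis(x, knots)
    # Penalize only the knot coefficients; the cubic polynomial part stays unpenalized.
    D = np.diag([0.0] * 4 + [1.0] * len(knots))
    solve_xt = np.linalg.solve(X.T @ X + lam * D, X.T)  # (X'X + lam D)^{-1} X'
    beta = solve_xt @ y
    fitted = X @ beta
    edf = np.trace(X @ solve_xt)  # trace of the influence (hat) matrix
    return beta, fitted, edf


def gcv_score(y, fitted, edf):
    """GCV criterion: n * RSS / (n - edf)^2."""
    n = len(y)
    rss = np.sum((y - fitted) ** 2)
    return n * rss / (n - edf) ** 2


rng = np.random.default_rng(1)
n = 200
x = np.sort(rng.uniform(-1, 1, n))
y = np.sin(np.pi * x) + rng.normal(0, 0.3, n)
knots = np.linspace(-1, 1, 22)[1:-1]  # 20 interior knots

lams = 10.0 ** np.arange(-8, 3)
scores = [gcv_score(y, *fit_penalized_spline(x, y, lam, knots)[1:]) for lam in lams]
best = lams[int(np.argmin(scores))]
_, fitted, edf = fit_penalized_spline(x, y, best, knots)
print(f"GCV-selected lambda = {best:g}, effective degrees of freedom = {edf:.1f}")
```

A very small \(\lambda\) lets the fit use nearly all 24 basis functions, while a very large one shrinks it towards a plain cubic polynomial, so the effective degrees of freedom always lie between 4 and 24; GCV picks a compromise from the data.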
Simple example
import numpy as np
import pandas as pd
from pymgcv.gam import GAM
from pymgcv.plot import plot_gam
from pymgcv.terms import L, S
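rng = np.random.default_rng(1)
n = 100
# Four uniform covariates on [-1, 1]
x0, x1, x2, x3 = [rng.uniform(-1, 1, n) for _ in range(4)]
# Response: linear in x0, smooth in x1, a smooth interaction of x2 and x3, plus Gaussian noise
y = (
    0.5 * x0 +
    np.sin(np.pi * x1) +
    np.cos(np.pi * x2) * np.sin(np.pi * x3) +
    rng.normal(0, 0.3, n)
)
data = pd.DataFrame({'x0': x0, 'x1': x1, 'x2': x2, 'x3': x3, 'y': y})
# Linear term for x0, a smooth of x1 and a bivariate smooth of (x2, x3)
gam = GAM({'y': L('x0') + S('x1') + S('x2', 'x3')})
gam.fit(data)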
fig, ax = plot_gam(gam, residuals=True, ncols=3) # plot partial effects
fig.set_size_inches(9, 3)
fig.show()
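Because the data are simulated, the plotted partial effects can be compared with the true functions. One detail matters here: mgcv imposes a sum-to-zero (centring) constraint on each smooth, so the fitted smooth of x1 estimates \(\sin(\pi x_1)\) minus its mean over the data, with the removed means absorbed into the intercept. The sketch below uses only numpy (no pymgcv calls) to compute those centred targets for the covariates generated above; how to overlay them on the pymgcv plots is not shown, since that would rely on plotting details not covered here.

```python
import numpy as np

# Re-create the example's covariates (same seed and draw order as above)
rng = np.random.default_rng(1)
n = 100
x0, x1, x2, x3 = [rng.uniform(-1, 1, n) for _ in range(4)]


def centred(values):
    """Subtract the sample mean, mirroring mgcv's sum-to-zero constraint on smooths."""
    return values - values.mean()


raw_f1 = np.sin(np.pi * x1)
raw_f23 = np.cos(np.pi * x2) * np.sin(np.pi * x3)

true_f1 = centred(raw_f1)    # what S('x1') estimates
true_f23 = centred(raw_f23)  # what S('x2', 'x3') estimates
true_slope_x0 = 0.5          # what the coefficient of L('x0') estimates
intercept_target = raw_f1.mean() + raw_f23.mean()  # means absorbed by the intercept
print(f"intercept target = {intercept_target:.3f}")
```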
What next?
Key topics to explore next:
- Terms: The types of terms supported by pymgcv (e.g. linear terms, smooths, interactions)
- Basis Functions: The different types of basis functions available
- The examples in the sidebar!