Skip to content

pymgcv: Generalized Additive Models in Python

pymgcv provides a python interface and visualisation tools for R's powerful mgcv library for fitting Generalized Additive Models (GAMs).

Installation

Warning

The package is very new. Expect breaking changes without waning until the package stabilises.

Installing the python package only includes the python package and dependencies. This means an R installation with mgcv is also required. Conda and pixi provide two convenient options for handling both Python and R dependencies:

Using conda:

  • Install conda e.g. miniforge
  • Install uv (or use pip).
conda create --name my_env python r-base r-mgcv
conda activate my_env
uv add pymgcv
source .venv/bin/activate  # Or shell/OS specific acitivation  

Using pixi:

  • Install pixi
    pixi init
    pixi add python r-base r-mgcv
    pixi add --pypi pymgcv
    pixi shell
    

Using either method the below example should now run e.g. in the terminal after running python, or in an IDE after selecting the pixi/conda environment

What are GAMs?

Generalized Additive Models (GAMs) are a flexible class of statistical models that extend linear models by allowing non-linear relationships between predictors and the response variable. For example the model may have the form

$$ g(\mathbb{E}[Y]) = \beta_0 + \sum_{j=1}^p f_j(x_j), $$ Where:

  • \(g\) is the link function, which transforms the expected value of the response variable to a space where modelling with the sum of smooths is reasonable.
  • \(f_j\) are smooth functions (e.g. splines) which capture the non-linear relationship between the features and the response.
  • Bivariate \(f(x_1, x_2)\) (and multivariate) smooths are also possible, when interactions are important.

Why GAMs?

  • Flexibility: Capture non-linear relationships automatically
  • Interpretability: The additive nature allows each term to be visualized and understood separately
  • Statistical rigor: Built-in smoothing parameter estimation and uncertainty quantification

Simple example

import numpy as np
import pandas as pd

from pymgcv.gam import GAM
from pymgcv.plot import plot_gam
from pymgcv.terms import L, S

rng = np.random.default_rng(1)
n = 100
x0, x1, x2, x3 = [rng.uniform(-1, 1, n) for _ in range(4)]
y = (
    0.5 * x0 +
    np.sin(np.pi * x1) +
    np.cos(np.pi * x2) * np.sin(np.pi * x3) +
    rng.normal(0, 0.3, n)
)
data = pd.DataFrame({'x0': x0, 'x1': x1, 'x2': x2, 'x3': x3, 'y': y})

gam = GAM({'y': L('x0') + S('x1') + S('x2', 'x3')})
gam.fit(data)
fig, ax = plot_gam(gam, residuals=True, ncols=3) # plot partial effects
fig.set_size_inches(9, 3)
fig.show()

img

What next?

The key bits of information to explore:

  1. Terms: The types of terms supported by pymgcv (e.g. smooths, linear, interactions, etc)
  2. Basis Functions: The different types of basis functions available
  3. The examples in the side bar!