
A Detailed Implementation on Equinox with JAX Native Modules, Filtered Transforms, Stateful Layers, and End-to-End Training Workflows

In this tutorial, we explore Equinox, a lightweight and elegant neural network library built on JAX, and show how to use it. We begin by looking at how eqx.Module treats models as PyTrees, which makes parameter handling, transformation, and serialization simple and explicit. We then work through static fields, filtered transformations such as filter_jit and filter_grad, PyTree manipulation utilities, stateful layers such as BatchNorm, and a complete end-to-end training workflow for a toy regression problem. Throughout, we focus on clean, executable code that shows not only how Equinox works but also why it fits so well into the JAX ecosystem for research and practical experimentation.

!pip install equinox optax jaxtyping matplotlib -q


import jax
import jax.numpy as jnp
import equinox as eqx
import optax
from jaxtyping import Array, Float, Int, PRNGKeyArray
from typing import Optional
import matplotlib.pyplot as plt
import time


print(f"JAX version    : {jax.__version__}")
print(f"Equinox version: {eqx.__version__}")
print(f"Devices        : {jax.devices()}")


print("\n" + "="*60)
print("SECTION 1: eqx.Module fundamentals")
print("="*60)


class Linear(eqx.Module):
   weight: Float[Array, "out in"]
   bias:   Float[Array, "out"]


   def __init__(self, in_size: int, out_size: int, *, key: PRNGKeyArray):
       wkey, bkey = jax.random.split(key)
       self.weight = jax.random.normal(wkey, (out_size, in_size)) * 0.1
       self.bias = jax.random.normal(bkey, (out_size,)) * 0.01


   def __call__(self, x: Float[Array, "in"]) -> Float[Array, "out"]:
       return self.weight @ x + self.bias




key = jax.random.PRNGKey(0)
lin = Linear(4, 2, key=key)


leaves, treedef = jax.tree_util.tree_flatten(lin)
print("Leaves shapes:", [l.shape for l in leaves])
print("Treedef:", treedef)


print("\n" + "="*60)
print("SECTION 2: Static fields")
print("="*60)


class Conv1dBlock(eqx.Module):
   conv:        eqx.nn.Conv1d
   norm:        eqx.nn.LayerNorm
   activation:  str = eqx.field(static=True)


   def __init__(self, channels: int, kernel: int, activation: str, *, key: PRNGKeyArray):
       self.conv       = eqx.nn.Conv1d(channels, channels, kernel, padding="SAME", key=key)
       self.norm       = eqx.nn.LayerNorm((channels,))
       self.activation = activation


   def __call__(self, x: Float[Array, "C L"]) -> Float[Array, "C L"]:
       x = self.conv(x)
       x = jax.vmap(self.norm)(x.T).T
       if self.activation == "relu":
           return jax.nn.relu(x)
       elif self.activation == "gelu":
           return jax.nn.gelu(x)
       return x




key, subkey = jax.random.split(key)
block = Conv1dBlock(8, 3, "gelu", key=subkey)
x_seq = jnp.ones((8, 16))
out = block(x_seq)
print(f"Conv1dBlock output shape: {out.shape}")

We set up the full Equinox environment by installing the required libraries and importing JAX, Equinox, Optax, jaxtyping, Matplotlib, and other essentials. We immediately verify the runtime by printing the JAX and Equinox versions and the available devices, which confirms that our Colab environment is ready. We then start with the foundations of Equinox. We define a simple Linear module, create an instance of it, and inspect its PyTree leaves and structure. After that, we introduce a Conv1dBlock that shows how static fields and learnable layers work together in practice.

print("\n" + "="*60)
print("SECTION 3: Filtered transforms")
print("="*60)


class MLP(eqx.Module):
   layers: list
   dropout: eqx.nn.Dropout


   def __init__(self, in_size, hidden, out_size, *, key: PRNGKeyArray):
       k1, k2, k3 = jax.random.split(key, 3)
       self.layers  = [
           eqx.nn.Linear(in_size, hidden, key=k1),
           eqx.nn.Linear(hidden,  hidden, key=k2),
           eqx.nn.Linear(hidden,  out_size, key=k3),
       ]
       self.dropout = eqx.nn.Dropout(p=0.1)


   def __call__(self, x: Float[Array, "in"], *, key: Optional[PRNGKeyArray] = None) -> Float[Array, "out"]:
       for layer in self.layers[:-1]:
           x = jax.nn.relu(layer(x))
           if key is not None:
               key, subkey = jax.random.split(key)
               x = self.dropout(x, key=subkey)
       return self.layers[-1](x)




key, mk = jax.random.split(key)
mlp = MLP(8, 32, 4, key=mk)


@eqx.filter_jit
def forward(model, x, *, key):
   return model(x, key=key)


x_in  = jnp.ones((8,))
key, fk = jax.random.split(key)
y_out = forward(mlp, x_in, key=fk)
print(f"MLP output: {y_out}")


@eqx.filter_jit
def loss_fn(model: MLP,
           x: Float[Array, "B in"],
           y: Float[Array, "B out"],
           key: PRNGKeyArray) -> Float[Array, ""]:
   # One dropout key per example, so each vmapped call draws an independent mask.
   keys  = jax.random.split(key, x.shape[0])
   preds = jax.vmap(lambda xi, ki: model(xi, key=ki))(x, keys)
   return jnp.mean((preds - y) ** 2)


grad_fn = eqx.filter_grad(loss_fn)


# Separate keys for inputs, targets, and dropout, so they are not correlated.
key, xk, yk, dk = jax.random.split(key, 4)
X = jax.random.normal(xk, (16, 8))
Y = jax.random.normal(yk, (16, 4))
grads = grad_fn(mlp, X, Y, dk)
print(f"Grad of first layer weight: shape={grads.layers[0].weight.shape}, norm={jnp.linalg.norm(grads.layers[0].weight):.4f}")

We focus on Equinox's filtered transformations by building an MLP that includes both linear layers and dropout. We use filter_jit to compile the forward pass while allowing the model to contain non-array fields. We use filter_grad to compute gradients only for the floating-point array leaves that actually take part in learning. By running a forward pass and then evaluating gradients on synthetic data, we see how Equinox cleanly bridges model definition and differentiable computation in a JAX-friendly way.

print("\n" + "="*60)
print("SECTION 4: PyTree manipulation")
print("="*60)


arrays, non_arrays = eqx.partition(mlp, eqx.is_array)
print("Non-array leaves (structure only):", jax.tree_util.tree_leaves(non_arrays))


trainable_filter = jax.tree_util.tree_map(
   lambda _: True, mlp
)
trainable_filter = eqx.tree_at(
   lambda m: (m.layers[0].weight, m.layers[0].bias),
   trainable_filter,
   replace=(False, False),
)
trainable, frozen = eqx.partition(mlp, trainable_filter)
print("Frozen params (first layer weight shape):", frozen.layers[0].weight.shape)
print("Trainable first-layer weight (None = excluded):", trainable.layers[0].weight)


key, nk = jax.random.split(key)
new_weight = jax.random.normal(nk, mlp.layers[0].weight.shape)
mlp_updated = eqx.tree_at(lambda m: m.layers[0].weight, mlp, new_weight)
print("Updated first-layer weight norm:", jnp.linalg.norm(mlp_updated.layers[0].weight).item())


print("\n" + "="*60)
print("SECTION 5: Stateful layers — BatchNorm with inference mode")
print("="*60)


class BNModel(eqx.Module):
   linear1: eqx.nn.Linear
   bn:      eqx.nn.BatchNorm
   linear2: eqx.nn.Linear


   def __init__(self, in_f, hidden, out_f, *, key: PRNGKeyArray):
       k1, k2 = jax.random.split(key)
       self.linear1 = eqx.nn.Linear(in_f, hidden, key=k1)
       self.bn      = eqx.nn.BatchNorm(hidden, axis_name="batch")
       self.linear2 = eqx.nn.Linear(hidden, out_f, key=k2)


   def __call__(self, x, state, *, inference: bool = False):
       x, state = self.bn(jax.nn.relu(self.linear1(x)), state, inference=inference)
       return self.linear2(x), state




key, bk = jax.random.split(key)
bn_model, bn_state = eqx.nn.make_with_state(BNModel)(4, 16, 2, key=bk)


@eqx.filter_jit
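def train_step_bn(model, state, x):
   def single(x):
       return model(x, state)
   # axis_name lets BatchNorm average statistics across the batch;
   # out_axes=None returns a single shared (unbatched) state.
   outs, new_state = jax.vmap(single, axis_name="batch", out_axes=(0, None))(x)
   return outs, new_state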


key, xk = jax.random.split(key)
x_batch = jax.random.normal(xk, (8, 4))
preds, bn_state = train_step_bn(bn_model, bn_state, x_batch)
print(f"BNModel output shape: {preds.shape}")

We explore PyTree manipulation utilities that make Equinox especially flexible for research workflows. We partition the model into array and non-array parts, build a filter that marks the first layer as frozen, and use tree_at to perform an immutable update of a specific parameter without rewriting the whole model. We then extend the tutorial to stateful layers. We define a BatchNorm-based model, create both the model and its state with eqx.nn.make_with_state, and run a batched training-style pass that returns updated state information.

print("\n" + "="*60)
print("SECTION 6: Full training loop (ResNet MLP on noisy sine)")
print("="*60)


class ResBlock(eqx.Module):
   fc1: eqx.nn.Linear
   fc2: eqx.nn.Linear
   proj: Optional[eqx.nn.Linear]


   def __init__(self, dimension: int, *, key: PRNGKeyArray):
       k1, k2 = jax.random.split(key)
       self.fc1  = eqx.nn.Linear(dimension, dimension, key=k1)
       self.fc2  = eqx.nn.Linear(dimension, dimension, key=k2)
       self.proj = None


   def __call__(self, x):
       residual = x
       x = jax.nn.gelu(self.fc1(x))
       x = self.fc2(x)
       return jax.nn.gelu(x + residual)




class ResNetMLP(eqx.Module):
   embed:   eqx.nn.Linear
   blocks:  list
   head:    eqx.nn.Linear


   def __init__(self, in_size, hidden, out_size, n_blocks, *, key: PRNGKeyArray):
       keys = jax.random.split(key, n_blocks + 2)
       self.embed  = eqx.nn.Linear(in_size, hidden, key=keys[0])
       self.blocks = [ResBlock(hidden, key=keys[i+1]) for i in range(n_blocks)]
       self.head   = eqx.nn.Linear(hidden, out_size, key=keys[-1])


   def __call__(self, x):
       x = jax.nn.gelu(self.embed(x))
       for block in self.blocks:
           x = block(x)
       return self.head(x)




def make_dataset(n: int, key: PRNGKeyArray):
   xk, nk = jax.random.split(key)
   x = jax.random.uniform(xk, (n, 1), minval=-1.0, maxval=1.0)
   y = jnp.sin(2 * jnp.pi * x) + 0.1 * jax.random.normal(nk, (n, 1))
   return x, y


key, dk = jax.random.split(key)
X_train, Y_train = make_dataset(2048, dk)
key, dk = jax.random.split(key)
X_val,   Y_val   = make_dataset(512,  dk)


key, mk = jax.random.split(key)
model = ResNetMLP(1, 64, 1, n_blocks=4, key=mk)


schedule = optax.warmup_cosine_decay_schedule(
   init_value=0.0, peak_value=3e-3,
   warmup_steps=200, decay_steps=2000
)
optimiser = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(schedule))
opt_state = optimiser.init(eqx.filter(model, eqx.is_array))


@eqx.filter_jit
def train_step(model, opt_state, x, y):
   def compute_loss(model, x, y):
       preds = jax.vmap(model)(x)
       return jnp.mean((preds - y) ** 2)


   loss, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y)
   updates, opt_state_new = optimiser.update(
       grads, opt_state, eqx.filter(model, eqx.is_array)
   )
   model_new = eqx.apply_updates(model, updates)
   return model_new, opt_state_new, loss




@eqx.filter_jit
def evaluate(model, x, y):
   preds = jax.vmap(model)(x)
   return jnp.mean((preds - y) ** 2)

We build a deeper end-to-end learning example by defining a residual block and a ResNetMLP model for a noisy sine regression task. We generate synthetic training and validation datasets, initialize the model, configure a warmup cosine learning-rate schedule, and prepare the optimizer state using only the model's array leaves. We also define the jitted train_step and evaluate functions, which provide the core training and validation mechanics for the full workflow.

BATCH  = 128
EPOCHS = 30
steps_per_epoch = len(X_train) // BATCH
train_losses, val_losses = [], []


t0 = time.time()
for epoch in range(EPOCHS):
   key, sk = jax.random.split(key)
   perm = jax.random.permutation(sk, len(X_train))
   X_s, Y_s = X_train[perm], Y_train[perm]


   epoch_loss = 0.0
   for step in range(steps_per_epoch):
       xb = X_s[step*BATCH:(step+1)*BATCH]
       yb = Y_s[step*BATCH:(step+1)*BATCH]
       model, opt_state, loss = train_step(model, opt_state, xb, yb)
       epoch_loss += loss.item()


   val_loss = evaluate(model, X_val, Y_val).item()
   train_losses.append(epoch_loss / steps_per_epoch)
   val_losses.append(val_loss)


   if (epoch + 1) % 5 == 0:
       print(f"Epoch {epoch+1:3d}/{EPOCHS}  "
             f"train_loss={train_losses[-1]:.5f}  "
             f"val_loss={val_losses[-1]:.5f}")


print(f"\nTotal training time: {time.time()-t0:.1f}s")


print("\n" + "="*60)
print("SECTION 7: Save & load model weights")
print("="*60)


eqx.tree_serialise_leaves("model_weights.eqx", model)


key, mk2 = jax.random.split(key)
model_skeleton = ResNetMLP(1, 64, 1, n_blocks=4, key=mk2)
model_loaded   = eqx.tree_deserialise_leaves("model_weights.eqx", model_skeleton)


# Compare every array leaf, not just the first one.
diff = max(
   jnp.max(jnp.abs(a - b)).item()
   for a, b in zip(
       jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array)),
       jax.tree_util.tree_leaves(eqx.filter(model_loaded, eqx.is_array)),
   )
)
print(f"Max weight difference after reload: {diff:.2e}  (should be 0.0)")


fig, axes = plt.subplots(1, 2, figsize=(12, 4))


axes[0].plot(train_losses, label="Train MSE", color="#4C72B0")
axes[0].plot(val_losses,   label="Val MSE",   color="#DD8452", linestyle="--")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("MSE")
axes[0].set_title("Training curves")
axes[0].legend()
axes[0].grid(True, alpha=0.3)


x_plot  = jnp.linspace(-1, 1, 300).reshape(-1, 1)
y_true  = jnp.sin(2 * jnp.pi * x_plot)
y_pred  = jax.vmap(model)(x_plot)


axes[1].scatter(X_val[:100], Y_val[:100], s=10, alpha=0.4, color="grey", label="Data")
axes[1].plot(x_plot, y_true, color="#4C72B0",  linewidth=2, label="True f(x)")
axes[1].plot(x_plot, y_pred, color="#DD8452", linewidth=2, linestyle="--", label="Predicted")
axes[1].set_xlabel("x")
axes[1].set_ylabel("y")
axes[1].set_title("Sine regression fit")
axes[1].legend()
axes[1].grid(True, alpha=0.3)


plt.tight_layout()
plt.savefig("equinox_tutorial.png", dpi=150)
plt.show()
print("\nDone! Plot saved to equinox_tutorial.png")


print("\n" + "="*60)
print("BONUS: jaxpr inspection debug tip")
print("="*60)


jaxpr = jax.make_jaxpr(jax.vmap(model))(x_plot)
n_eqns = len(jaxpr.jaxpr.eqns)
print(f"Traced ResNetMLP jaxpr has {n_eqns} equations (ops) for batch input {x_plot.shape}")

We run the complete training loop over multiple epochs, shuffling the data, processing mini-batches, and tracking both training and validation losses. We then serialize the trained model with Equinox utilities, rebuild a skeleton model with the same structure, verify that deserialization restores the weights exactly, and visualize the learned fit and the loss curves. Finally, we inspect the traced computation graph with jax.make_jaxpr, which gives a useful debugging and introspection view of how the trained Equinox model executes under JAX.

In conclusion, we built a solid practical understanding of how Equinox helps us write clean, modular, JAX-native deep learning code without unnecessary abstraction. We saw how to define custom modules, manage static and trainable components, apply filtered transformations safely, work with stateful layers, train a residual MLP, save and reload model weights, and inspect traced computations. Along the way, we saw how Equinox gives us the flexibility of raw JAX while still providing the structure needed for modern model development. The result is a complete hands-on foundation for using Equinox confidently in more advanced machine learning experiments and research workflows.



The post A Detailed Implementation on Equinox with JAX Native Modules, Filtered Transforms, Stateful Layers, and End-to-End Training Workflows appeared first on MarkTechPost.
