A Coding Implementation on Spatial Graph Neural Networks for Urban Function Inference Using city2graph, OSMnx, and PyTorch Geometric
In this tutorial, we construct an end-to-end spatial graph learning pipeline using city2graph. We begin by collecting real urban POI data and street network data from OpenStreetMap, with a synthetic fallback to ensure the workflow remains reliable. We then engineer spatial features, construct several proximity graph families, and compare how different graph-building strategies represent the same urban environment. After that, we create both heterogeneous and homogeneous graph structures, convert them into PyTorch Geometric format, and train a GraphSAGE model to predict POI categories from spatial structure. Throughout this process, we combine geospatial data processing, graph construction, and GNN-based urban function inference into a single practical workflow.
Installing city2graph and Importing Geospatial and Graph Learning Libraries
!pip -q install "city2graph[cpu]" osmnx contextily scikit-learn 2>/dev/null
import warnings, numpy as np, pandas as pd, geopandas as gpd
warnings.filterwarnings("ignore")
from shapely.geometry import Point
import matplotlib.pyplot as plt
import city2graph as c2g
print("city2graph version:", getattr(c2g, "__version__", "unknown"))
print("PyTorch / PyG available:", c2g.is_torch_available())
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.utils import to_undirected
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, f1_score
from sklearn.decomposition import PCA
SEED = 42
np.random.seed(SEED); torch.manual_seed(SEED)
We start by installing the required libraries and importing the geospatial, graph learning, and machine learning tools used throughout the tutorial. We check that city2graph and PyTorch Geometric are available so the rest of the workflow can run properly. We also set a fixed random seed to make the synthetic fallback data, the training split, and the model results more reproducible.
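Because the later sections depend on several optional packages (OSMnx, PyTorch Geometric, contextily), it can be useful to check which ones are importable before running the heavier cells. A minimal sketch using only the standard library follows; `available` is a hypothetical helper, not part of city2graph, and `importlib.util.find_spec` only locates a top-level module without importing it.

```python
import importlib.util

def available(modules):
    """Return {module_name: True/False} depending on whether each top-level module can be found."""
    # find_spec locates a module on sys.path without importing it;
    # it returns None when the module is not installed.
    return {name: importlib.util.find_spec(name) is not None for name in modules}

status = available(["city2graph", "osmnx", "torch", "torch_geometric", "contextily"])
for name, ok in status.items():
    print(f"{name:<16} {'available' if ok else 'missing'}")
```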
Collecting OpenStreetMap POI Data with a Synthetic Fallback
CENTER = (35.6595, 139.7005)  # (lat, lon) near Shibuya, Tokyo
DIST_M = 1100
TAG_QUERIES = {
    "food": {"amenity": ["restaurant", "cafe", "fast_food", "bar", "pub"]},
    "retail": {"shop": True},
    "education": {"amenity": ["school", "university", "college", "kindergarten", "library"]},
    "health": {"amenity": ["hospital", "clinic", "pharmacy", "doctors", "dentist"]},
}
def to_points(gdf):
    # Collapse polygons/lines to a single representative point inside each feature
    g = gdf.copy()
    g["geometry"] = g.geometry.representative_point()
    return g
poi_gdf, segments_gdf = None, None
try:
    import osmnx as ox
    ox.settings.use_cache = True
    ox.settings.log_console = False
    frames = []
    for label, tags in TAG_QUERIES.items():
        try:
            f = ox.features_from_point(CENTER, tags=tags, dist=DIST_M)
            f = f[f.geometry.notna()]
            if len(f):
                f = to_points(f)[["geometry"]].copy()
                f["category"] = label
                frames.append(f)
        except Exception as e:
            print(f"  (skip {label}: {e})")
    if not frames:
        raise RuntimeError("No POIs returned from Overpass.")
    poi_gdf = gpd.GeoDataFrame(pd.concat(frames, ignore_index=True), crs="EPSG:4326")
    G = ox.graph_from_point(CENTER, dist=DIST_M, network_type="walk")
    segments_gdf = ox.graph_to_gdfs(G, nodes=False, edges=True).reset_index(drop=True)[["geometry"]]
    print(f"OSM acquisition OK -> {len(poi_gdf)} POIs, {len(segments_gdf)} street segments")
except Exception as e:
    print(f"OSM unavailable ({e}) -> generating synthetic clustered POIs.")
    rng = np.random.default_rng(SEED)
    cats = list(TAG_QUERIES.keys())
    # Eight cluster centers within +/-0.01 degrees of the study-area center (lon, lat order)
    centers = rng.uniform(-0.01, 0.01, size=(8, 2)) + np.array(CENTER[::-1])
    rows = []
    for ci, c in enumerate(centers):
        dom = cats[ci % len(cats)]
        n = rng.integers(40, 90)
        pts = c + rng.normal(0, 0.0016, size=(n, 2))
        for (lon, lat) in pts:
            # 75% of points take the cluster's dominant category, the rest a random one
            cat = dom if rng.random() < 0.75 else rng.choice(cats)
            rows.append({"geometry": Point(lon, lat), "category": cat})
    poi_gdf = gpd.GeoDataFrame(rows, crs="EPSG:4326")
    segments_gdf = None
    print(f"Synthetic dataset -> {len(poi_gdf)} POIs")
if len(poi_gdf) > 700:
    poi_gdf = poi_gdf.sample(700, random_state=SEED).reset_index(drop=True)
# Project to a local UTM CRS so that all distances below are in meters
metric_crs = poi_gdf.estimate_utm_crs()
poi_gdf = poi_gdf.to_crs(metric_crs).reset_index(drop=True)
if segments_gdf is not None:
    segments_gdf = segments_gdf.to_crs(metric_crs)
print("Class balance:\n", poi_gdf["category"].value_counts())
We collect real POI data from OpenStreetMap within 1.1 km of a point in Shibuya, Tokyo, and group the locations into broad urban function categories: food, retail, education, and health. We also download the walkable street network so that the POIs can later be linked with urban-form features. If the OSM request fails, we generate a synthetic clustered dataset, which keeps the tutorial runnable even when online data access is unavailable. Finally, we cap the dataset at 700 POIs and project everything to a local UTM coordinate system so that all later distances are in meters.
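The synthetic fallback draws Gaussian clusters around random centers and gives each cluster a dominant category. The NumPy-only sketch below restates that generator as a reusable function; `synthetic_clustered_pois` is a hypothetical name, it returns label indices rather than a GeoDataFrame, and because it vectorizes the per-point draws it will not reproduce the tutorial's exact points. One detail worth noting: the 25% of "other" draws are uniform over all four categories, so they can also land on the dominant one, and the expected fraction of points matching their cluster's dominant category is 0.75 + 0.25 × 1/4 = 0.8125 rather than 0.75.

```python
import numpy as np

def synthetic_clustered_pois(center_lonlat, categories, n_clusters=8,
                             n_range=(40, 90), spread=0.0016, p_dominant=0.75, seed=42):
    """Sketch of a clustered POI generator, loosely mirroring the tutorial's fallback.
    Returns (lon/lat array of shape (n, 2), label indices, dominant label per point)."""
    rng = np.random.default_rng(seed)
    # Cluster centers scattered within +/-0.01 degrees of the study-area center
    centers = rng.uniform(-0.01, 0.01, size=(n_clusters, 2)) + np.asarray(center_lonlat)
    lonlat, labels, dominant = [], [], []
    for ci, c in enumerate(centers):
        dom = ci % len(categories)            # each cluster gets a dominant category index
        n = rng.integers(*n_range)            # cluster size; upper bound is exclusive
        pts = c + rng.normal(0, spread, size=(n, 2))
        # Keep the dominant label with probability p_dominant, else draw uniformly
        keep = rng.random(n) < p_dominant
        lab = np.where(keep, dom, rng.integers(0, len(categories), size=n))
        lonlat.append(pts)
        labels.append(lab)
        dominant.append(np.full(n, dom))
    return np.vstack(lonlat), np.concatenate(labels), np.concatenate(dominant)

cats = ["food", "retail", "education", "health"]
xy, y_syn, dom_syn = synthetic_clustered_pois((139.7005, 35.6595), cats)
print(xy.shape, np.bincount(y_syn, minlength=len(cats)))
print("fraction matching cluster's dominant category:", (y_syn == dom_syn).mean())
```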
Engineering Spatial Features and Building Proximity Graph Families
poi_gdf["cx"] = poi_gdf.geometry.x
poi_gdf["cy"] = poi_gdf.geometry.y
coords = poi_gdf[["cx", "cy"]].to_numpy()
# Local density: number of other POIs within 150 m
nn = NearestNeighbors(radius=150.0).fit(coords)
poi_gdf["local_density"] = [len(idx) - 1 for idx in nn.radius_neighbors(coords, return_distance=False)]
if segments_gdf is not None and len(segments_gdf):
    try:
        joined = gpd.sjoin_nearest(poi_gdf[["geometry"]], segments_gdf[["geometry"]],
                                   distance_col="dist_street")
        # sjoin_nearest can return several rows per POI on ties; keep the minimum distance
        poi_gdf["dist_street"] = joined.groupby(level=0)["dist_street"].min().reindex(poi_gdf.index).fillna(0.0)
    except Exception:
        poi_gdf["dist_street"] = 0.0
else:
    poi_gdf["dist_street"] = 0.0
poi_gdf["category"] = poi_gdf["category"].astype("category")
poi_gdf["label"] = poi_gdf["category"].cat.codes.astype(int)
CLASS_NAMES = list(poi_gdf["category"].cat.categories)
print("Classes:", CLASS_NAMES)
def graph_stats(name, builder):
    try:
        nodes, edges = builder()
        # Count how often each node appears as an edge endpoint
        deg = pd.Series(np.r_[edges.index.get_level_values(0),
                              edges.index.get_level_values(1)]).value_counts()
        return name, len(edges), round(deg.mean(), 2), (nodes, edges)
    except Exception as e:
        return name, f"ERR: {e}", None, None
builders = {
    "KNN (k=8)": lambda: c2g.knn_graph(poi_gdf, distance_metric="euclidean", k=8, as_nx=False),
    "Delaunay": lambda: c2g.delaunay_graph(poi_gdf, as_nx=False),
    "Gabriel": lambda: c2g.gabriel_graph(poi_gdf, as_nx=False),
    "RNG": lambda: c2g.relative_neighborhood_graph(poi_gdf, as_nx=False),
    "EMST": lambda: c2g.euclidean_minimum_spanning_tree(poi_gdf, as_nx=False),
    "Waxman": lambda: c2g.waxman_graph(poi_gdf, distance_metric="euclidean", r0=150, beta=0.6),
}
print("\n--- Proximity graph comparison ---")
print(f"{'graph':<14}{'#edges':>10}{'avg_degree':>12}")
built = {}
for nm, b in builders.items():
    name, ne, avgdeg, payload = graph_stats(nm, b)
    print(f"{name:<14}{str(ne):>10}{str(avgdeg):>12}")
    if payload: built[nm] = payload
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
for ax, key in zip(axes, ["KNN (k=8)", "Delaunay", "EMST"]):
    if key in built:
        n_, e_ = built[key]
        e_.plot(ax=ax, linewidth=0.4, color="#3b7dd8", alpha=0.6)
        poi_gdf.plot(ax=ax, markersize=4, color="#d83b5c")
    ax.set_title(key); ax.set_axis_off()
plt.suptitle("Spatial graph topologies on the same POI set", y=1.02)
plt.tight_layout(); plt.show()
We engineer spatial features for each POI by extracting its projected coordinates, counting how many other POIs lie within 150 m (local density), and measuring the distance to the nearest street segment. We then assign class labels and build several families of proximity graphs: KNN, Delaunay, Gabriel, relative neighborhood graph (RNG), Euclidean minimum spanning tree (EMST), and Waxman. We compare their edge counts and average degrees, then visualize selected topologies to see how differently they connect the same set of POIs.
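To make the differences in the comparison table concrete, here is a brute-force NumPy sketch of three of the families, written from their textbook definitions rather than from city2graph's internals (the function names are hypothetical). For points in general position these graphs are nested, EMST ⊆ RNG ⊆ Gabriel ⊆ Delaunay, so one would expect the edge counts of those four builders to decrease along that chain. The O(n³) loops are only meant for small illustrative point sets.

```python
import numpy as np

def pairwise_sq_dists(xy):
    diff = xy[:, None, :] - xy[None, :, :]
    return (diff ** 2).sum(-1)

def gabriel_edges(xy):
    """Edge (u, v) iff no other point lies in the closed disk whose diameter is uv."""
    d2 = pairwise_sq_dists(xy)
    n = len(xy)
    edges = set()
    for u in range(n):
        for v in range(u + 1, n):
            mask = np.ones(n, dtype=bool); mask[[u, v]] = False
            # w is inside the diametral disk iff d2(u,w) + d2(v,w) <= d2(u,v)
            if np.all(d2[u, mask] + d2[v, mask] > d2[u, v]):
                edges.add((u, v))
    return edges

def rng_edges(xy):
    """Edge (u, v) iff no third point w satisfies max(d(u,w), d(v,w)) < d(u,v)."""
    d = np.sqrt(pairwise_sq_dists(xy))
    n = len(xy)
    edges = set()
    for u in range(n):
        for v in range(u + 1, n):
            mask = np.ones(n, dtype=bool); mask[[u, v]] = False
            if not np.any(np.maximum(d[u, mask], d[v, mask]) < d[u, v]):
                edges.add((u, v))
    return edges

def emst_edges(xy):
    """Prim's algorithm on the complete Euclidean graph, O(n^2)."""
    d = np.sqrt(pairwise_sq_dists(xy))
    n = len(xy)
    in_tree = np.zeros(n, dtype=bool); in_tree[0] = True
    best = d[0].copy()               # cheapest known connection of each node to the tree
    parent = np.zeros(n, dtype=int)  # tree node that provides that connection
    edges = set()
    for _ in range(n - 1):
        v = int(np.argmin(np.where(in_tree, np.inf, best)))
        u = int(parent[v])
        edges.add((min(u, v), max(u, v)))
        in_tree[v] = True
        closer = d[v] < best
        best = np.where(closer, d[v], best)
        parent = np.where(closer, v, parent)
    return edges

gen = np.random.default_rng(0)
xy = gen.uniform(0, 1000, size=(60, 2))   # 60 random points in a 1 km square
emst, rng_g, gab = emst_edges(xy), rng_edges(xy), gabriel_edges(xy)
for name, E in [("EMST", emst), ("RNG", rng_g), ("Gabriel", gab)]:
    # For an undirected graph, average degree = 2|E| / |V|
    print(f"{name:<8} edges={len(E):4d} avg_degree={2 * len(E) / len(xy):.2f}")
print("nested:", emst <= rng_g <= gab)
```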
Constructing Heterogeneous and Homogeneous Graphs in PyTorch Geometric
nodes_dict = {}
for cat in CLASS_NAMES:
    sub = poi_gdf[poi_gdf["category"] == cat].copy().reset_index(drop=True)
    nodes_dict[cat] = sub[["geometry", "cx", "cy", "local_density"]]
try:
    # Connect each layer (category) to nearby nodes in the other layers
    _, bridge_edges = c2g.bridge_nodes(nodes_dict, proximity_method="knn", k=3,
                                       distance_metric="euclidean")
    hetero = c2g.gdf_to_pyg(
        nodes_dict, bridge_edges,
        node_feature_cols={cat: ["cx", "cy", "local_density"] for cat in CLASS_NAMES},
    )
    print("\nHeteroData node types:", hetero.node_types)
    print("HeteroData edge types:")
    for et in hetero.edge_types:
        print(f"  {et}: {hetero[et].edge_index.shape[1]} edges")
except Exception as e:
    hetero = None
    print("Heterogeneous build skipped:", e)
nodes, edges = c2g.knn_graph(poi_gdf, distance_metric="euclidean", k=8, as_nx=False)
deg = pd.Series(np.r_[edges.index.get_level_values(0),
                      edges.index.get_level_values(1)]).value_counts()
nodes["degree"] = deg.reindex(nodes.index).fillna(0).astype(float)
for col in ["cx", "cy", "local_density", "dist_street", "label"]:
    if col not in nodes.columns:
        nodes[col] = poi_gdf.loc[nodes.index, col].values
FEATS = ["cx", "cy", "local_density", "dist_street", "degree"]
nodes[FEATS] = StandardScaler().fit_transform(nodes[FEATS].astype(float))
data = c2g.gdf_to_pyg(nodes, edges, node_feature_cols=FEATS, node_label_cols=["label"])
data.edge_index = to_undirected(data.edge_index)
data.x = data.x.float()
y = data.y.long().view(-1)
N, num_classes = data.num_nodes, int(y.max()) + 1
print(f"\nHomogeneous Data: {N} nodes, {data.edge_index.shape[1]} directed edges, "
      f"{data.x.shape[1]} features, {num_classes} classes")
We construct a heterogeneous multi-layer graph by separating POIs into node types based on their urban function categories. We then use k-nearest-neighbor bridge edges (k=3) to connect nearby nodes across different layers and convert the result into PyTorch Geometric HeteroData format. After that, we build a homogeneous KNN graph (k=8), attach node degree alongside the engineered features, standardize them, and prepare the final PyG Data object for GraphSAGE training.
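To see what cross-layer KNN bridging means at the index level, here is a brute-force NumPy sketch that, for every ordered pair of distinct layers, links each source node to its k nearest nodes in the target layer. It illustrates the idea under that assumption and is not city2graph's `bridge_nodes` implementation; the function name and the `"near"` relation label are hypothetical, though the `(src_type, relation, dst_type)` keys follow PyG's edge-type convention.

```python
import numpy as np

def knn_bridge_edges(layers, k=3):
    """For every ordered pair of distinct layers (src, dst), link each src node to
    its k nearest dst nodes. Returns {(src, "near", dst): int array of shape (2, E)}."""
    out = {}
    for src, a in layers.items():
        for dst, b in layers.items():
            if src == dst or len(b) == 0:
                continue
            d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)  # |src| x |dst| squared distances
            kk = min(k, len(b))                                   # a small layer may have < k nodes
            nbr = np.argsort(d2, axis=1)[:, :kk]                  # kk closest dst nodes per src node
            src_idx = np.repeat(np.arange(len(a)), kk)
            out[(src, "near", dst)] = np.vstack([src_idx, nbr.ravel()])
    return out

gen = np.random.default_rng(1)
layers = {c: gen.uniform(0, 1000, size=(n, 2))
          for c, n in [("food", 30), ("retail", 20), ("education", 5), ("health", 2)]}
bridges = knn_bridge_edges(layers, k=3)
for et, ei in bridges.items():
    print(et, ei.shape)
```

Capping k at the size of the target layer matters for sparse categories such as health, which may contain only a handful of POIs in a 1.1 km radius.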
Defining and Training a GraphSAGE Model for POI Classification
perm = torch.randperm(N, generator=torch.Generator().manual_seed(SEED))
n_tr, n_va = int(0.6 * N), int(0.2 * N)
train_mask = torch.zeros(N, dtype=torch.bool); train_mask[perm[:n_tr]] = True
val_mask = torch.zeros(N, dtype=torch.bool); val_mask[perm[n_tr:n_tr + n_va]] = True
test_mask = torch.zeros(N, dtype=torch.bool); test_mask[perm[n_tr + n_va:]] = True
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_dim, hidden, out_dim, p=0.3):
        super().__init__()
        self.c1 = SAGEConv(in_dim, hidden)
        self.c2 = SAGEConv(hidden, hidden)
        self.lin = torch.nn.Linear(hidden, out_dim)
        self.p = p
    def forward(self, x, ei, return_emb=False):
        h = F.relu(self.c1(x, ei))
        h = F.dropout(h, p=self.p, training=self.training)
        h = F.relu(self.c2(h, ei))
        out = self.lin(h)
        return (out, h) if return_emb else out
model = GraphSAGE(data.x.shape[1], 64, num_classes)
opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def evaluate(mask):
    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index).argmax(1)
    yt, yp = y[mask].numpy(), pred[mask].numpy()
    return accuracy_score(yt, yp), f1_score(yt, yp, average="macro")
print("\n--- Training GraphSAGE ---")
best_val, best_state = 0.0, None
for epoch in range(1, 201):
    model.train(); opt.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[train_mask], y[train_mask])
    loss.backward(); opt.step()
    if epoch % 20 == 0:
        va_acc, va_f1 = evaluate(val_mask)
        if va_acc > best_val:
            # Snapshot the weights with the best validation accuracy so far
            best_val, best_state = va_acc, {k: v.clone() for k, v in model.state_dict().items()}
        print(f"epoch {epoch:3d} | loss {loss.item():.3f} | val_acc {va_acc:.3f} | val_f1 {va_f1:.3f}")
if best_state is not None: model.load_state_dict(best_state)
te_acc, te_f1 = evaluate(test_mask)
print(f"\nTEST accuracy={te_acc:.3f}  macro-F1={te_f1:.3f}")
We split the graph nodes into training, validation, and test masks (60/20/20) so the model can be trained and evaluated on disjoint nodes. We define a two-layer GraphSAGE model that learns node representations from both node features and graph structure. We train the model for 200 epochs, check validation accuracy and macro-F1 every 20 epochs, keep the model state with the best validation accuracy, and finally report test performance.
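The GraphSAGE layer itself is compact. As we understand PyG's default `SAGEConv` (mean aggregation with a root weight), each node's new representation combines a linear map of the mean of its neighbors' features with a linear map of its own features. The NumPy sketch below writes that update out explicitly; `sage_mean_layer` and its weight arguments are hypothetical names, and isolated nodes are given a zero neighbor aggregate.

```python
import numpy as np

def sage_mean_layer(x, edge_index, W_neigh, W_self, b):
    """One GraphSAGE update with mean aggregation:
    h_i = W_neigh @ mean_{j in N(i)} x_j + W_self @ x_i + b
    edge_index[0] holds source nodes j, edge_index[1] holds target nodes i."""
    n = x.shape[0]
    src, dst = edge_index
    agg = np.zeros(x.shape, dtype=float)
    np.add.at(agg, dst, x[src])                        # sum incoming neighbor features per target
    deg = np.bincount(dst, minlength=n).astype(float)  # in-degree of every node
    agg /= np.maximum(deg, 1.0)[:, None]               # mean; isolated nodes keep a zero vector
    return agg @ W_neigh.T + x @ W_self.T + b

# Toy graph: path 0-1-2 stored in both directions, node 3 isolated
x = np.array([[1.0], [2.0], [3.0], [10.0]])
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
out = sage_mean_layer(x, edge_index, np.eye(1), np.eye(1), np.zeros(1))
print(out.ravel())  # values 3, 4, 5, 10
```

Stacking two such layers, as the GraphSAGE class above does, lets each node's embedding depend on neighbors up to two hops away in the KNN graph.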
Visualizing Embeddings and Running a Heterogeneous GNN Forward Pass
model.eval()
with torch.no_grad():
    logits, emb = model(data.x, data.edge_index, return_emb=True)
pred = logits.argmax(1).numpy()
emb2d = PCA(n_components=2).fit_transform(emb.numpy())
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
for cls in range(num_classes):
    m = y.numpy() == cls
    axes[0].scatter(emb2d[m, 0], emb2d[m, 1], s=10, label=CLASS_NAMES[cls], alpha=0.7)
axes[0].set_title("GraphSAGE node embeddings (PCA), colored by TRUE class")
axes[0].legend(fontsize=8); axes[0].set_xticks([]); axes[0].set_yticks([])
plot_gdf = nodes.copy(); plot_gdf["pred"] = pred
plot_gdf["pred_name"] = [CLASS_NAMES[p] for p in pred]
plot_gdf.plot(ax=axes[1], column="pred_name", legend=True, markersize=12, cmap="tab10")
axes[1].set_title("Predicted urban function (mapped back to geography)")
axes[1].set_axis_off()
try:
    import contextily as ctx
    ctx.add_basemap(axes[1], crs=plot_gdf.crs, source=ctx.providers.CartoDB.Positron)
except Exception:
    pass
plt.tight_layout(); plt.show()
if hetero is not None:
    try:
        for nt in hetero.node_types:
            hetero[nt].x = hetero[nt].x.float()
        class HGNN(torch.nn.Module):
            def __init__(self, hid, out):
                super().__init__()
                self.c1 = SAGEConv((-1, -1), hid)
                self.c2 = SAGEConv((-1, -1), out)
            def forward(self, x, ei):
                # Written as a homogeneous model; to_hetero duplicates it per node/edge type
                x = F.relu(self.c1(x, ei))
                return self.c2(x, ei)
        hmodel = to_hetero(HGNN(32, 16), hetero.metadata(), aggr="sum")
        out_dict = hmodel(hetero.x_dict, hetero.edge_index_dict)
        print("\nHeterogeneous GNN output embedding shapes:")
        for nt, t in out_dict.items():
            print(f"  {nt}: {tuple(t.shape)}")
    except Exception as e:
        print("Hetero GNN forward skipped:", e)
print("\nDone: proximity comparison, hetero construction, and a trained spatial GNN.")
We use the trained GraphSAGE model to extract node embeddings and predictions from the homogeneous graph. We reduce the learned embeddings to two dimensions with PCA and visualize them alongside a geographic prediction map to understand how the model separates urban functions. We also run a heterogeneous GNN forward pass with to_hetero, showing that the tutorial supports both homogeneous training and heterogeneous graph experimentation.
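The PCA step reduces the 64-dimensional hidden vectors to two coordinates for plotting. As a sketch of what that projection computes (assuming nothing beyond NumPy; `pca_2d` is a hypothetical helper, not part of scikit-learn), the top two principal components can be obtained from an SVD of the mean-centered embedding matrix. scikit-learn's `PCA` also centers the data and uses an SVD, so the coordinates should agree with it up to the sign of each component.

```python
import numpy as np

def pca_2d(emb):
    """Project embeddings onto their top-2 principal components via SVD."""
    centered = emb - emb.mean(axis=0)              # PCA operates on mean-centered data
    U, S, Vt = np.linalg.svd(centered, full_matrices=False)
    coords = centered @ Vt[:2].T                   # scores on PC1 and PC2
    explained = (S ** 2) / (S ** 2).sum()          # explained-variance ratio per component
    return coords, explained[:2]

gen = np.random.default_rng(0)
# Stand-in for GraphSAGE embeddings: 200 nodes x 64 dims with decreasing spread per dim
fake_emb = gen.normal(size=(200, 64)) * np.linspace(5, 0.1, 64)
coords2d, ratio = pca_2d(fake_emb)
print(coords2d.shape, ratio.round(3))
```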
Key Takeaways
- city2graph turns raw OpenStreetMap POI and street data into spatial graphs.
- Six proximity graph families (KNN, Delaunay, Gabriel, RNG, EMST, Waxman) connect the same POIs differently.
- A synthetic clustered fallback keeps the workflow runnable without OSM access.
- A two-layer GraphSAGE model predicts urban function categories from spatial structure.
- The pipeline supports both homogeneous training and heterogeneous graph experimentation via to_hetero.
Conclusion
In conclusion, we completed a full spatial GNN pipeline that transforms raw city data into graph-based learning and visualization. We compared several proximity graph methods, built a heterogeneous multi-layer graph, trained a homogeneous GraphSAGE classifier, and inspected the learned embeddings and geographic predictions. This gives us a practical understanding of how spatial relationships among POIs can be represented as graph structures and used to predict urban functions. It also shows how city2graph, GeoPandas, OSMnx, and PyTorch Geometric work together to support advanced geospatial machine learning experiments in a Colab-friendly setup.
Check out the Full Codes with Notebook here.
The post A Coding Implementation on Spatial Graph Neural Networks for Urban Function Inference Using city2graph, OSMnx, and PyTorch Geometric appeared first on MarkTechPost.
