mims-harvard / graphxai
GraphXAI: Resource to support the development and evaluation of GNN explainers
Home Page: https://zitniklab.hms.harvard.edu/projects/GraphXAI
License: MIT License
Thanks for your efforts! I had a quick question about the calculation of graph_exp_faith in metrics.py.
I wanted to clarify why, after the model outputs on the original data are computed and passed through softmax, the log is taken, whereas no log is applied after the softmax of the model outputs on the perturbed data.* Thanks for clarifying.
*1 - torch.exp(-F.kl_div(org_softmax.log(), pert_softmax, None, None, 'sum')).item()
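For reference when reading the quoted line: torch.nn.functional.kl_div(input, target) expects input as log-probabilities and target as plain probabilities (unless log_target=True is passed), and with reduction='sum' it returns sum(target * (log(target) - input)). The positional None, None, 'sum' in the quoted line fill size_average, reduce and reduction. With input = org_softmax.log() and target = pert_softmax, the expression is therefore 1 - exp(-KL(pert || org)). The sketch below checks that identity numerically; the function names are hypothetical and the logits are made up for illustration.

```python
import torch
import torch.nn.functional as F


def exp_faith_from_logits(org_logits: torch.Tensor, pert_logits: torch.Tensor) -> float:
    """Compute 1 - exp(-KL(pert || org)), mirroring the quoted expression."""
    org_softmax = F.softmax(org_logits, dim=-1)
    pert_softmax = F.softmax(pert_logits, dim=-1)
    # kl_div(input, target): input must be log-probabilities, target plain probabilities
    kl = F.kl_div(org_softmax.log(), pert_softmax, reduction="sum")
    return 1 - torch.exp(-kl).item()


def manual_kl(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """KL(p || q) = sum p * (log p - log q), written out explicitly."""
    return (p * (p.log() - q.log())).sum()
```

Identical logits give a score of 0, and the score grows toward 1 as the perturbed prediction diverges from the original one.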
I believe there is a significant error in PGMExplainer.
The docstring for the function 'PGM_perturb_node_features' (in GraphXAI/graphxai/utils/perturb/perturb.py) states that it should return:
x_pert (torch.Tensor, [n x d]): perturbed feature matrix
node_mask (torch.Tensor, [n]): Boolean mask of perturbed nodes
However:
Thank you again for your great work!
However, I found a confusing code snippet in your implementation in GraphXAI/graphxai/explainers/pg_explainer.py (lines 119-123):
if self.explain_graph:
    h = torch.cat([h1, h2], dim=1)
else:
    h3 = emb.repeat(h1.shape[0], 1)
    h = torch.cat([h1, h2], dim=1)
I believe that "self.explain_graph = True" corresponds to graph-level explanation, while "self.explain_graph = False" corresponds to node-level explanation. However, in your implementation (I know you adapted it from the DIG library), h3 = emb.repeat(h1.shape[0], 1) is defined but never used, so both branches compute the same h = torch.cat([h1, h2], dim=1). How, then, are graph-level and node-level explanations distinguished?
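In the original PGExplainer formulation, and in PyTorch Geometric's PGExplainer, the node-level edge-mask MLP also receives the embedding of the node being explained, so the per-edge input is [h1, h2, h3] rather than [h1, h2]. A minimal sketch of that distinction, assuming h1 and h2 are the source and target embeddings of each edge and emb is the full node-embedding matrix; build_edge_inputs is a hypothetical name, not GraphXAI's code.

```python
from typing import Optional

import torch


def build_edge_inputs(emb: torch.Tensor, edge_index: torch.Tensor,
                      explain_graph: bool, node_idx: Optional[int] = None) -> torch.Tensor:
    """Per-edge MLP inputs: [h1, h2] for graph-level, [h1, h2, h3] for node-level."""
    h1 = emb[edge_index[0]]  # source-node embedding, one row per edge
    h2 = emb[edge_index[1]]  # target-node embedding, one row per edge
    if explain_graph:
        return torch.cat([h1, h2], dim=1)  # shape [E, 2d]
    if node_idx is None:
        raise ValueError("node_idx is required for node-level explanations")
    # Repeat the embedding of the node being explained once per edge
    h3 = emb[node_idx].repeat(h1.shape[0], 1)
    return torch.cat([h1, h2, h3], dim=1)  # shape [E, 3d]
```

The MLP consuming these inputs then needs an input width of 2d for graph-level and 3d for node-level explanations, which is the other place the two modes have to differ.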
First of all, here is my model code:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import networkx as nx
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import to_undirected
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
from captum.attr import Saliency, IntegratedGradients
import random

excel_file_path = "/content/drive/MyDrive/GNN/chest_x_ray_dataset.xlsx"
df = pd.read_excel(excel_file_path)
df = df.fillna(df.mean(numeric_only=True))  # Impute missing values with column means
X = df.drop('class', axis=1).values  # Features
y = df['class'].values  # Target variable

scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)
X_log_transformed = np.log(X_normalized + 1)  # Adding 1 to avoid log(0)

K = 10  # Number of nearest neighbors to consider (adjust as needed)
knn = NearestNeighbors(n_neighbors=K, algorithm='ball_tree')
knn.fit(X_log_transformed)  # Use the log-transformed data for graph construction
knn_indices = knn.kneighbors(return_distance=False)

graph = nx.Graph()
for i in range(len(df)):
    graph.add_node(i)
for i, neighbors in enumerate(knn_indices):
    for neighbor in neighbors:
        if i != neighbor:
            graph.add_edge(i, neighbor)
labels = {i: label for i, label in enumerate(y)}
nx.set_node_attributes(graph, labels, 'label')

# nx.Graph.edges() lists each undirected edge once; PyG expects both directions
edge_index = torch.tensor(np.array(list(graph.edges())).T, dtype=torch.long)
edge_index = to_undirected(edge_index)
x = torch.tensor(X_log_transformed, dtype=torch.float)  # Use the log-transformed data
y = torch.tensor(y, dtype=torch.long)
data = Data(x=x, edge_index=edge_index, y=y)


class CustomGNN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(CustomGNN, self).__init__()
        self.conv1 = SAGEConv(num_features, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        self.conv4 = SAGEConv(hidden_channels, hidden_channels)
        self.conv5 = SAGEConv(hidden_channels, hidden_channels)
        self.conv6 = SAGEConv(hidden_channels, hidden_channels)  # Additional layer
        self.conv7 = SAGEConv(hidden_channels, hidden_channels)  # Additional layer
        self.conv8 = SAGEConv(hidden_channels, num_classes)  # Adjust output layer
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.bn2 = nn.BatchNorm1d(hidden_channels)
        self.bn3 = nn.BatchNorm1d(hidden_channels)
        self.bn4 = nn.BatchNorm1d(hidden_channels)
        self.bn5 = nn.BatchNorm1d(hidden_channels)
        self.bn6 = nn.BatchNorm1d(hidden_channels)  # Additional layer
        self.bn7 = nn.BatchNorm1d(hidden_channels)  # Additional layer
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index, batch=None):
        # `batch` is unused for node classification on a single graph; kept for compatibility
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv4(x, edge_index)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv5(x, edge_index)
        x = self.bn5(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv6(x, edge_index)  # Additional layer
        x = self.bn6(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv7(x, edge_index)  # Additional layer
        x = self.bn7(x)
        x = self.relu(x)
        x = self.dropout(x)
        # SAGEConv's third positional argument is `size`, not `batch`, so it is not passed here
        x = self.conv8(x, edge_index)  # Adjust output layer
        return x


best_hidden_channels = 256  # Replace with your best value
model = CustomGNN(num_features=X_log_transformed.shape[1], hidden_channels=best_hidden_channels, num_classes=6)


def weight_init(m):
    # Note: the model contains no nn.Conv2d modules, so this initializer has no effect as written
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)


model.apply(weight_init)

best_lr = 0.005689229656484651  # Replace with your best value
optimizer = optim.Adam(model.parameters(), lr=best_lr)
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, verbose=True)
best_accuracy = 0.0
patience = 150
early_stopping_counter = 0

for epoch in range(350):
    model.train()
    optimizer.zero_grad()
    # For a single Data object, data.batch is None; the model accepts and ignores it
    out = model(data.x, data.edge_index, data.batch)
    loss = criterion(out, data.y)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
        y_pred = out.argmax(dim=1)
        # Note: accuracy is measured on the same nodes used for training (no train/test split)
        accuracy = accuracy_score(data.y, y_pred)

    scheduler.step(loss)  # Adjust learning rate based on loss
    print(f'Epoch {epoch}, Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}')

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
    if early_stopping_counter >= patience:
        print("Early stopping")
        break

print(f"Best Accuracy: {best_accuracy:.4f}")
Now I want to explain my model's predictions using your library, but I can't get it to work for a custom dataset like mine. Could you provide a simple code example?
How can I implement something like "vis_shapegraph.ipynb" for my data?
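Independent of GraphXAI's own explainer classes (whose call signatures are not shown in this thread), the core idea behind gradient-based node explanations can be sketched in plain PyTorch: take the logit of the predicted class for one node and backpropagate it to every node's input features. ToyMeanGNN is a hypothetical one-layer mean-aggregation stand-in for a SAGEConv model, and node_saliency is a hypothetical helper; captum's Saliency, already imported above, computes the same absolute gradient once the model is wrapped in a suitable forward function.

```python
import torch
import torch.nn as nn


class ToyMeanGNN(nn.Module):
    """Hypothetical one-layer mean-aggregation GNN standing in for a SAGEConv model."""

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        self.lin_self = nn.Linear(in_dim, num_classes)
        self.lin_neigh = nn.Linear(in_dim, num_classes)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        # Sum incoming neighbor features at each target node, then divide by in-degree
        agg = torch.zeros_like(x).index_add(0, dst, x[src])
        ones = torch.ones(dst.numel(), dtype=x.dtype)
        deg = torch.zeros(x.size(0), dtype=x.dtype).index_add(0, dst, ones)
        agg = agg / deg.clamp(min=1).unsqueeze(1)
        return self.lin_self(x) + self.lin_neigh(agg)


def node_saliency(model: nn.Module, x: torch.Tensor, edge_index: torch.Tensor,
                  node_idx: int) -> torch.Tensor:
    """Absolute gradient of node_idx's predicted-class logit w.r.t. all node features."""
    model.eval()
    x = x.clone().requires_grad_(True)
    logits = model(x, edge_index)
    pred = int(logits[node_idx].argmax())
    logits[node_idx, pred].backward()
    return x.grad.abs()  # shape [num_nodes, num_features]
```

On a path graph 0-1-2-3 with this one-layer model, only node 0 and its direct neighbor 1 influence node 0's prediction, so the saliency rows for nodes 2 and 3 are exactly zero. With the 8-layer model above the receptive field is 8 hops, and the same call works with data.x and data.edge_index since its forward defaults batch to None.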
Lastly, I am facing a library installation issue:
ERROR: Could not find a version that satisfies the requirement torch-cluster (from versions: 0.1.1, 0.2.3, 0.2.4, 1.0.1, 1.0.3, 1.1.1, 1.1.2, 1.1.3, 1.1.4, 1.1.5, 1.2.1, 1.2.2, 1.2.3, 1.2.4, 1.3.0, 1.4.0, 1.4.1, 1.4.2, 1.4.3a1, 1.4.3, 1.4.4, 1.4.5, 1.5.2, 1.5.3, 1.5.4, 1.5.5, 1.5.6, 1.5.7, 1.5.8, 1.5.9, 1.6.0, 1.6.1)
ERROR: No matching distribution found for torch-cluster
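One common workaround (an assumption here, not confirmed in this thread) is that PyPI has no prebuilt torch-cluster wheel for the installed PyTorch/CUDA combination, while PyG hosts matching wheels on a separate index that pip reads through its -f/--find-links option. A minimal sketch of how that index URL is formed; pyg_wheel_index is a hypothetical helper, and the exact set of published versions should be checked on the PyG installation page.

```python
from typing import Optional


def pyg_wheel_index(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the assumed PyG find-links URL for a given torch build."""
    base = torch_version.split("+")[0]  # "2.1.0+cu121" -> "2.1.0"
    # torch.version.cuda is None on CPU-only builds
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"
```

Usage: print pyg_wheel_index(torch.__version__, torch.version.cuda) and pass the result to pip install torch-cluster -f followed by that URL.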
Broken by the new PyG implementation; needs to be fixed.
Usage:
  pip install [options] <requirement specifier> [package-index-options] ...
  pip install [options] -r <requirements file> [package-index-options] ...
  pip install [options] [-e] <vcs project url> ...
  pip install [options] [-e] <local project path> ...
  pip install [options] <archive url/path> ...

-e option requires 1 argument
Should be able to map between explanations on one graph.
Thanks for the great work. I installed graphxai from GitHub, but I get an error when I try to import graphxai in Python. Details below:
  File "<ipython-input-328-53dd77503f89>", line 1, in <module>
    import graphxai
  File "/opt/anaconda3/lib/python3.8/site-packages/graphxai/__init__.py", line 1, in <module>
    from graphxai.utils.explanation import Explanation
ModuleNotFoundError: No module named 'graphxai.utils'
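The traceback shows graphxai/__init__.py being found in site-packages while graphxai.utils is not, which suggests (an assumption, not confirmed in this thread) that the installed copy is missing that subpackage. A stdlib-only sketch for checking this without triggering the failing __init__.py: importlib.util.find_spec on a top-level name locates the package without executing it. missing_subpackages is a hypothetical helper.

```python
import importlib.util
import os
from typing import List


def missing_subpackages(package: str, subpackages: List[str]) -> List[str]:
    """Return the entries of `subpackages` absent from the installed `package` directory."""
    spec = importlib.util.find_spec(package)  # top-level name: located, not executed
    if spec is None or spec.submodule_search_locations is None:
        raise ModuleNotFoundError(f"{package} is not installed as a package")
    missing = []
    for name in subpackages:
        found = any(
            os.path.isdir(os.path.join(loc, name))
            or os.path.isfile(os.path.join(loc, name + ".py"))
            for loc in spec.submodule_search_locations
        )
        if not found:
            missing.append(name)
    return missing
```

For this report, missing_subpackages("graphxai", ["utils"]) returning ["utils"] would confirm that the installed copy lacks the directory the __init__.py imports from.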
Thanks for this great work!
I was trying to run the real-world examples with MUTAG and noticed that the formal/realworld/mutag/EXPS directory is missing.
This directory is used by graph_eval.py, which looks for *.pt files inside it.
How can I get this data?
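Until the files are available, a missing directory is easier to diagnose if the lookup fails loudly instead of silently iterating over nothing. A minimal sketch of such a check; find_explanation_files is a hypothetical helper and is not the code in graph_eval.py.

```python
from pathlib import Path
from typing import List


def find_explanation_files(exp_dir: str) -> List[Path]:
    """Return the sorted *.pt files in exp_dir, raising a clear error if there are none."""
    path = Path(exp_dir)
    if not path.is_dir():
        raise FileNotFoundError(f"Expected directory {path} does not exist")
    files = sorted(path.glob("*.pt"))
    if not files:
        raise FileNotFoundError(f"No *.pt files found in {path}")
    return files
```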
Thanks for your efforts on this work!
I understand the examples you provided for the MUTAG TUDataset, but I need to run explainers on Planetoid datasets. More specifically, I want to implement test_SubgraphX_Cora.py, which would be very similar to test_SubgraphX_MUTAG.py.
Can you provide some guidance for this?
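The main structural difference between the two settings is that MUTAG is graph classification over many small graphs, while Cora is node classification on one large graph, so a node-level explainer typically works inside the target node's L-hop computation graph rather than on a whole graph. A pure-PyTorch sketch of extracting that neighborhood, similar in spirit to torch_geometric.utils.k_hop_subgraph but not that function; k_hop_computation_graph is a hypothetical name and node relabeling is omitted.

```python
from typing import Tuple

import torch


def k_hop_computation_graph(node_idx: int, num_hops: int, edge_index: torch.Tensor,
                            num_nodes: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Nodes within num_hops of node_idx (messages flow src -> dst) and the edges among them."""
    src, dst = edge_index
    visited = torch.zeros(num_nodes, dtype=torch.bool)
    visited[node_idx] = True
    frontier = visited.clone()
    for _ in range(num_hops):
        # Nodes that send messages into the current frontier
        reached = torch.zeros_like(visited)
        reached[src[frontier[dst]]] = True
        frontier = reached & ~visited  # only newly reached nodes expand next hop
        visited |= reached
    edge_mask = visited[src] & visited[dst]
    return visited.nonzero(as_tuple=True)[0], edge_index[:, edge_mask]
```

For an L-layer GNN, num_hops = L gives exactly the nodes that can influence the target node's prediction.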