Comments (7)
This is a novel and intriguing method: training an Energy-Based Model (EBM) as a Generalised Additive Model (GAM) inside a huge CNN or Transformer architecture. While training classic EBMs usually involves end-to-end optimisation techniques, it is possible to modify them to operate within a broader neural network architecture and train them incrementally (batch-by-batch).
Here is a code example of how you can achieve this:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim

# Define a neural network architecture with an EBM layer
class CNNWithEBM(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.ebm = nn.Linear(32 * 14 * 14, 1)  # Example linear EBM layer

    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # Flatten the output
        energy = self.ebm(x)
        return energy

# Generate a synthetic dataset
def generate_data(batch_size=32):
    # Generate random data and labels
    data = torch.randn(batch_size, 1, 28, 28)  # MNIST-like data
    labels = torch.randint(0, 2, (batch_size,))
    return data, labels

# Instantiate the model
model = CNNWithEBM()

# Define loss function (energy-based loss)
criterion = nn.MSELoss()

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
num_batches = 100  # Mini-batches per epoch (illustrative value)
batch_size = 32
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_idx in range(num_batches):
        # Generate mini-batch data
        data, labels = generate_data(batch_size)
        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        energy = model(data)
        # Compute loss
        loss = criterion(energy.squeeze(1), labels.float())  # Energy-based loss
        # Backward pass
        loss.backward()
        # Update parameters
        optimizer.step()
        # Accumulate total loss
        total_loss += loss.item()
    # Print average loss for the epoch
    print(f"Epoch {epoch + 1}, Avg. Loss: {total_loss / num_batches:.4f}")
I hope that this helps.
Thank you
from interpret.
Is it possible to see the code you are working on, so that I can see how to contribute more?
from interpret.
Hi @JWKKWJ123 -- This kind of federated learning approach isn't something that we support out of the box. You can kind of hack it as you've discovered using merge_ebms, but the implementation isn't ideal. At some point we'll provide a better interface for building EBMs one boosting round at a time, and from batches.
Your other point about DNNs and EBMs (based on decision trees) is quite pertinent too. The training strategies are quite different and it's not clear to me that bringing them together will result in an ideal union. An alternative approach that I might suggest would be to train the DNN as normal, then remove the last layer, and train the EBM to replace it on the now frozen DNN. Will this approach work for you?
from interpret.
Dear Sunnycasmir,
Thank you for your reply!
More specifically, I want to use an EBM (Explainable Boosting Machine) as the output layer of a large CNN/transformer. I considered wrapping the EBM as a custom torch layer, but this would make the EBM untrainable. So my question is: how can an EBM be trained incrementally (batch-by-batch) as a custom torch layer? I don't think the example code answers this question.
from interpret.
Hi all,
I have an update this week:
I think the main difficulty is that deep-learning models and GAMs (including EBM) have very different training strategies. GAMs need to read all the training data at once and update the weights of all shape functions sequentially on the residuals. Deep-learning models need to take the training data in mini-batches because of the memory limit (I use a batch size of 4 now), and update the model step by step.
I would like to use the EBM as the output block in a large end-to-end 3D CNN. The question then is: can the EBM be updated progressively, step by step (mini-batch by mini-batch), simultaneously with the CNN?
I am trying to use merge_ebms to train the EBM in batches, and it seems to work with a large batch size.
This is the code I use to put the EBM into a deep-learning model. For now the EBM is untrainable inside the CNN, because I am going to train the EBM and the CNN alternately:
import torch
import torch.nn as nn

class EBM_layer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, ebm):
        # Leave the autograd graph: the EBM works on NumPy arrays
        x = x.detach().cpu().numpy()
        # Probability of the positive class from the already-fitted EBM
        output_pro_ebm = ebm.predict_proba(x)[:, 1]
        output_pro_ebm = torch.tensor(output_pro_ebm, dtype=torch.float32, requires_grad=True)
        return output_pro_ebm.unsqueeze(1)

# forward() of the enclosing end-to-end model (class definition not shown).
# The EBM and the CNN are trained alternately, so a trained ebm is passed in each epoch.
def forward(self, x, ebm):
    outs = []
    for i in range(N):  # N: number of CNN pathways
        outs.append(self.cnnlist[i](x))
    # Concatenate the features extracted by the multiple CNNs
    out_all = torch.cat(outs, 1)
    out_pro = self.EBM_layer(out_all, ebm)
    return out_pro
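One thing worth checking with a layer like this is where gradients stop. The sketch below mimics the same round trip as `EBM_layer.forward` (detach, go through NumPy, re-wrap as a new tensor) with a hypothetical linear `backbone` in place of the CNNs and a sigmoid in place of `ebm.predict_proba`; both are stand-ins chosen only so the snippet runs without a fitted EBM.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)

backbone = nn.Linear(8, 4)  # hypothetical stand-in for the CNN pathways
x = torch.randn(16, 8)

features = backbone(x)

# Same round trip as in EBM_layer.forward: leave autograd, then re-wrap.
features_np = features.detach().cpu().numpy()
fake_proba = 1.0 / (1.0 + np.exp(-features_np.sum(axis=1)))  # stand-in for predict_proba(...)[:, 1]
out = torch.tensor(fake_proba, dtype=torch.float32, requires_grad=True).unsqueeze(1)

loss = out.mean()
loss.backward()

# The re-wrapped tensor is a new leaf, so backward() stops there and
# nothing reaches the backbone parameters.
backbone_received_grad = backbone.weight.grad is not None
print(backbone_received_grad)
```

So with this wrapper the upstream network gets no gradient through the EBM output; `requires_grad=True` only lets `backward()` run without an error. Any update to the CNNs has to come from a separate differentiable path or a separate training stage.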
from interpret.
Dear Paul,
Thank you very much for your reply! I'm glad I've made some progress now.
I found that it is possible to use merge_ebms to train the EBM in batches together with the DNN. But I am now using a huge DNN, so I can only set the batch size to 4, and this training strategy does not work when the batch size is below 10.
So after some trial and error, I developed a new training strategy (figure below), which trains the model alternately in two stages. I am currently using it in a case that takes both the whole image (global) and image patches (local) as input; each pathway in the end-to-end model is a CNN:
from interpret.
This training strategy works (I accidentally added the accuracy twice in the epoch between the two stages). It can provide the contributions of different pathways in a large composite DNN, without sacrificing performance:
from interpret.