
Comments (7)

sunnycasmir commented on May 29, 2024

This is a novel and intriguing method: training an Energy-Based Model (EBM) as a Generalised Additive Model (GAM) inside a huge CNN or Transformer architecture. While training classic EBMs usually involves end-to-end optimisation techniques, it is possible to modify them to operate within a broader neural network architecture and train them incrementally (batch-by-batch).

Here is a code example of how you can achieve this:

# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim

# Define a neural network architecture with an EBM head
class CNNWithEBM(nn.Module):
    def __init__(self):
        super(CNNWithEBM, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.ebm = nn.Linear(32 * 14 * 14, 1)  # Example linear EBM layer

    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # Flatten the output
        energy = self.ebm(x)
        return energy

# Generate a synthetic dataset
def generate_data(batch_size=32):
    # Random MNIST-like data and binary labels
    data = torch.randn(batch_size, 1, 28, 28)
    labels = torch.randint(0, 2, (batch_size,))
    return data, labels

# Instantiate the model
model = CNNWithEBM()

# Define the loss function (energy-based loss)
criterion = nn.MSELoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
batch_size = 32
num_batches = 100  # number of mini-batches per epoch

for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_idx in range(num_batches):
        # Generate a mini-batch
        data, labels = generate_data(batch_size)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        energy = model(data)

        # Compute the energy-based loss
        loss = criterion(energy.squeeze(), labels.float())

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()

        # Accumulate total loss
        total_loss += loss.item()

    # Print average loss for the epoch
    print(f"Epoch {epoch + 1}, Avg. Loss: {total_loss / num_batches:.4f}")

I hope that this helps.
Thank you


sunnycasmir commented on May 29, 2024

Is it possible to see the code you are working on, so I can see how to contribute more?


paulbkoch commented on May 29, 2024

Hi @JWKKWJ123 -- This kind of federated learning approach isn't something that we support out of the box. You can kind of hack it, as you've discovered, using merge_ebms, but the implementation isn't ideal. At some point we'll provide a better interface for building EBMs one boosting round at a time, and from batches.

Your other point about DNNs and EBMs (based on decision trees) is quite pertinent too. The training strategies are quite different, and it's not clear to me that bringing them together will result in an ideal union. An alternative approach that I might suggest would be to train the DNN as normal, then remove the last layer, and train the EBM to replace it on the now-frozen DNN. Will this approach work for you?
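A minimal sketch of that last suggestion, assuming a trained PyTorch model whose final layer has been removed (called feature_extractor here, a hypothetical name) and a train_loader yielding (data, labels) batches; the EBM is then fit once on the frozen DNN features:

import numpy as np
import torch
from interpret.glassbox import ExplainableBoostingClassifier

# feature_extractor: the trained DNN with its last layer removed (kept frozen)
feature_extractor.eval()

features, targets = [], []
with torch.no_grad():
    for data, labels in train_loader:
        feats = feature_extractor(data)          # shape: (batch, n_features)
        features.append(feats.cpu().numpy())
        targets.append(labels.numpy())

X = np.concatenate(features, axis=0)
y = np.concatenate(targets, axis=0)

# Train the EBM to replace the removed last layer, on the frozen features
ebm_head = ExplainableBoostingClassifier()
ebm_head.fit(X, y)

# Predictions now go: DNN features -> EBM head
probs = ebm_head.predict_proba(X)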


JWKKWJ123 commented on May 29, 2024

Dear Sunnycasmir,
Thank you for your reply!

More specifically, I want to use EBM (Explainable Boosting Machine) as the output layer of a large CNN/transformer. I considered using the EBM as a custom torch layer, but this would make the EBM untrainable. So my question is: how can the EBM be trained incrementally (batch-by-batch) as a custom torch layer? I think the example code doesn't solve this question.


JWKKWJ123 commented on May 29, 2024

Hi all,
I have an update this week:
I think the main difficulty is that deep-learning models and GAMs (including EBM) have very different training strategies. The GAMs need to read all the training data at once and update all the shape functions on the residuals sequentially. The deep-learning models need to take the training data in mini-batches because of the memory limit (I currently use a batch size of 4), and update the model step by step.
I would like to use the EBM as the output block of a large end-to-end 3D CNN. Then the question becomes: can the EBM be progressively updated step by step (mini-batch by mini-batch), simultaneously with the CNN?
I am trying to use merge_ebms() to train the EBM in batches, and it seems to work with a large batch size.
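As a sketch, this is roughly what that batch-wise workaround looks like, assuming X_batches and y_batches are lists of NumPy mini-batches (hypothetical names); one EBM is fit per batch and the models are then combined with merge_ebms:

from interpret.glassbox import ExplainableBoostingClassifier, merge_ebms

ebms = []
for X_batch, y_batch in zip(X_batches, y_batches):
    ebm = ExplainableBoostingClassifier()
    ebm.fit(X_batch, y_batch)       # fit one EBM per mini-batch
    ebms.append(ebm)

merged_ebm = merge_ebms(ebms)       # combine the per-batch EBMs into one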
This is the code where I put the EBM into a deep-learning model. For now I have made the EBM untrainable inside the CNN, because I am going to train the EBM and the CNN alternately (a rough sketch of that alternating loop follows the code):

class EBM_layer(nn.Module):
    def __init__(self, **kwargs):
        super(EBM_layer, self).__init__(**kwargs)

    def forward(self, x, ebm):
        # The EBM runs on NumPy, so the features are detached from the graph;
        # this makes the EBM untrainable by backpropagation.
        x = x.detach().cpu().numpy()
        output_pro_ebm = ebm.predict_proba(x)
        output_pro_ebm = output_pro_ebm[:, 1]
        output_pro_ebm = torch.tensor(output_pro_ebm, requires_grad=True)
        output_pro_ebm = output_pro_ebm.unsqueeze(1)
        return output_pro_ebm


# forward() of the composite model (the rest of the class is omitted here).
# I now train the EBM and the CNNs alternately, so a trained ebm is passed
# to the model in each epoch.
def forward(self, x, ebm):
    outs = []
    for i in range(N):                    # N = number of CNN pathways
        outs.append(self.cnnlist[i](x))
    out_all = torch.cat(outs, dim=1)      # concatenate the features extracted by the CNNs
    out_pro = self.EBM_layer(out_all, ebm)
    return out_pro
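A rough sketch of the alternating two-stage loop, under these assumptions: model is the composite model above with a hypothetical extract_features(x) method returning the concatenated CNN features, aux_head is a small differentiable classifier used only to give the CNNs a gradient signal (since the NumPy round-trip in EBM_layer blocks backpropagation), and train_loader yields (data, labels):

import numpy as np
import torch
import torch.nn as nn
from interpret.glassbox import ExplainableBoostingClassifier

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(aux_head.parameters()), lr=1e-4
)
num_epochs = 10

for epoch in range(num_epochs):
    # Stage 1: train the CNN pathways (through the auxiliary head); EBM untouched
    model.train()
    for data, labels in train_loader:
        optimizer.zero_grad()
        logits = aux_head(model.extract_features(data))
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

    # Stage 2: freeze the CNNs and re-fit the EBM on their features
    model.eval()
    feats, targets = [], []
    with torch.no_grad():
        for data, labels in train_loader:
            feats.append(model.extract_features(data).cpu().numpy())
            targets.append(labels.numpy())
    ebm = ExplainableBoostingClassifier()
    ebm.fit(np.concatenate(feats), np.concatenate(targets))

# At inference time, EBM_layer wraps ebm.predict_proba on the extracted features.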


JWKKWJ123 commented on May 29, 2024


Dear Paul,
Thank you very much for your reply! I'm glad I've made some progress now.
I found it is possible to use merge_ebms() to train the EBM in batches with a DNN. But I am now using a huge DNN, so I can only set the batch size to 4, and this training strategy cannot work when the batch size is < 10.
So after trial and error, I developed a new training strategy (figure below), which trains the model alternately in two stages. I now use it in a case that takes both the whole image (global) and image patches (local) as input; each pathway in the end-to-end model is a CNN:
[figure: two-stage alternating training strategy]


JWKKWJ123 commented on May 29, 2024

This training strategy works (I accidentally recorded the accuracy twice in the epoch between the two stages). It can provide the contributions of the different pathways in a large composite DNN without sacrificing performance:
[figure: training results]

