Comments (12)

qubvel commented on May 19, 2024

According to two related issues, it can be implemented as follows:

def set_regularization(model, 
                       kernel_regularizer=None, 
                       bias_regularizer=None, 
                       activity_regularizer=None):
    
    for layer in model.layers:
        
        # set kernel_regularizer
        if kernel_regularizer is not None and hasattr(layer, 'kernel_regularizer'):
            layer.kernel_regularizer = kernel_regularizer

        # set bias_regularizer
        if bias_regularizer is not None and hasattr(layer, 'bias_regularizer'):
            layer.bias_regularizer = bias_regularizer

        # set activity_regularizer
        if activity_regularizer is not None and hasattr(layer, 'activity_regularizer'):
            layer.activity_regularizer = activity_regularizer

# example
set_regularization(model, kernel_regularizer=keras.regularizers.l2(0.0001))
model.compile(...)  # you have to recompile the model if regularization is changed

I have not tested this code; if it works, it can be added as a utility function.

from segmentation_models.

qubvel commented on May 19, 2024

Hi, @Tyler-D
Do you mean the ability to add regularization to all convolution layers of the model?

Tyler-D commented on May 19, 2024

Well, I think it would be better if there were a function that adds a specific regularizer to all layers.

Tyler-D commented on May 19, 2024

Cool, that's exactly the function I want. I could help add it. What kind of tests do you need?

Tyler-D commented on May 19, 2024

Actually, I'm wondering whether it would be possible to build a segmentation pipeline on top of your repo, including training, evaluation, data loaders for public datasets (e.g. Pascal VOC, COCO), and even a tool to export the Keras model to an inference framework (e.g. TensorRT). I'm sure that would make this repository extremely appealing.

qubvel commented on May 19, 2024

Just test that it works as expected:

Regularization appears in conv/dense layers and is applied during training.
A saved and reloaded model still has regularization.
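
The two checks above could be sketched roughly as follows. This is only a minimal sketch, assuming `tf.keras`; `build_toy_model` and `regularization_report` are hypothetical helpers made up for illustration (they are not part of segmentation_models), and a JSON round trip stands in for a full save/load. Keras adds the entries of `model.losses` to the training loss, so a non-empty report suggests the regularizer is actually applied.

```python
import tensorflow as tf


def build_toy_model(kernel_regularizer=None):
    """Hypothetical toy model: one Dense layer, optionally regularized."""
    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(
        2, kernel_regularizer=kernel_regularizer, name="dense"
    )(inputs)
    return tf.keras.Model(inputs, outputs)


def regularization_report(model):
    """Map layer name -> number of loss terms, for layers that contribute any."""
    return {layer.name: len(layer.losses) for layer in model.layers if layer.losses}


plain = build_toy_model()
regularized = build_toy_model(tf.keras.regularizers.l2(1e-4))

# Rebuild from the serialized config and copy the weights over,
# as a stand-in for saving and loading the model.
reloaded = tf.keras.models.model_from_json(regularized.to_json())
reloaded.set_weights(regularized.get_weights())

print(regularization_report(plain))
print(regularization_report(regularized))
print(regularization_report(reloaded))
```

An empty report for a model that is supposed to be regularized would indicate that only the layer attributes or config were changed, not the losses used in training.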

qubvel commented on May 19, 2024

A segmentation pipeline is a cool idea; however, I think it should be built in another repo or written as an example here.
If you can recommend any cool repos with this kind of pipeline, it would be extremely helpful! 😄

Tyler-D commented on May 19, 2024

I've tried the code you offered in my training scripts, and it turns out that only the model config is changed. After some investigation, I found this. A workaround can be found here:

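from keras.models import model_from_json
from keras.regularizers import l2

def create_model():
    model = your_model()  # placeholder: build your model here
    model.save_weights("tmp.h5")

    # optionally do some other modifications (freezing layers, adding convolutions etc.)
    ...

    # WEIGHT_DECAY is assumed to be defined elsewhere
    regularizer = l2(WEIGHT_DECAY / 2)
    for layer in model.layers:
        for attr in ['kernel_regularizer', 'bias_regularizer']:
            if hasattr(layer, attr) and layer.trainable:
                setattr(layer, attr, regularizer)

    # rebuild the model from its config so the regularizers take effect,
    # then restore the saved weights
    out = model_from_json(model.to_json())
    out.load_weights("tmp.h5", by_name=True)

    return out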

This doesn't seem like an elegant way to do it. I'm thinking about how to refactor it.

qubvel commented on May 19, 2024

Yes, I agree, this is not an elegant way.
Here is another inelegant way, but at least it does not require saving the model:

import keras
from keras.models import model_from_json

def set_regularization(model,
                       kernel_regularizer=None,
                       bias_regularizer=None,
                       activity_regularizer=None):
    
    for layer in model.layers:
        
        # set kernel_regularizer
        if kernel_regularizer is not None and hasattr(layer, 'kernel_regularizer'):
            layer.kernel_regularizer = kernel_regularizer

        # set bias_regularizer
        if bias_regularizer is not None and hasattr(layer, 'bias_regularizer'):
            layer.bias_regularizer = bias_regularizer

        # set activity_regularizer
        if activity_regularizer is not None and hasattr(layer, 'activity_regularizer'):
            layer.activity_regularizer = activity_regularizer

    out = model_from_json(model.to_json())
    out.set_weights(model.get_weights())

    return out

new_model = set_regularization(model, kernel_regularizer=keras.regularizers.l2(0.0001))
new_model.compile(...) 

Tyler-D commented on May 19, 2024

Hi @qubvel. I've tested the new implementation, and it works well! You can add it, see #54.

qubvel commented on May 19, 2024

Hi @Tyler-D, ok, no problem

mathmanu commented on May 19, 2024

Try this:

import tensorflow as tf

# a utility function to add weight decay after the model is defined.
def add_weight_decay(model, weight_decay):
	if (weight_decay is None) or (weight_decay == 0.0):
		return

	# recurse into nested models; add one L2 loss per trainable weight of each layer
	def add_decay_loss(m, factor):
		if isinstance(m, tf.keras.Model):
			for layer in m.layers:
				add_decay_loss(layer, factor)
		else:
			for param in m.trainable_weights:
				with tf.keras.backend.name_scope('weight_regularizer'):
					# bind param as a default argument so each callable keeps its own weight
					regularizer = lambda param=param: tf.keras.regularizers.l2(factor)(param)
					m.add_loss(regularizer)

	# weight decay and l2 regularization differ by a factor of 2
	add_decay_loss(model, weight_decay/2.0)
	return
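
The factor of 2 mentioned in the comment above can be checked with a small worked example. This is a sketch under stated assumptions: plain SGD, a single scalar weight, and an L2 penalty of the form `l2_factor * w**2` (which is how `keras.regularizers.l2` scales its penalty). The function and variable names are hypothetical, chosen only for illustration. Since the gradient of `l2_factor * w**2` is `2 * l2_factor * w`, an L2 factor of `weight_decay / 2` gives the same update as weight decay.

```python
def sgd_step_weight_decay(w, grad, lr, weight_decay):
    """One SGD step where the weight decay term is added directly to the gradient."""
    return w - lr * (grad + weight_decay * w)


def sgd_step_l2_penalty(w, grad, lr, l2_factor):
    """One SGD step on loss + l2_factor * w**2; the penalty's gradient is 2 * l2_factor * w."""
    return w - lr * (grad + 2.0 * l2_factor * w)


# Arbitrary illustrative values
w, grad, lr, weight_decay = 0.8, 0.3, 0.1, 1e-2

decayed = sgd_step_weight_decay(w, grad, lr, weight_decay)
penalized = sgd_step_l2_penalty(w, grad, lr, weight_decay / 2.0)

# The two updates agree up to floating-point rounding
print(abs(decayed - penalized) < 1e-12)
```

Note that this equivalence is specific to plain SGD; with adaptive optimizers such as Adam, an L2 penalty and decoupled weight decay generally lead to different updates.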
