
Run validation every N batches (keras issue, closed)

mpetteno commented on May 18, 2024
Run validation every N batches

from keras.

Comments (8)

fchollet commented on May 18, 2024

> Looking at the code in fit it seems to me that the evaluation is run only at the end of each epoch and even if I set steps_per_epoch then the next epoch the elements will be taken from the beginning again, thus I'm not using the whole dataset. Is that correct or am I missing something?

Can you check that in practice with an example? I expect that you will continue drawing from the same dataset. It needs to be able to generate steps_per_epoch * epochs batches. You may need to call .repeat() on it to achieve it.
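
To make the arithmetic above concrete, here is a minimal sketch in plain Python. The helper name plan_fit and its return keys are hypothetical, not part of Keras; it only works out how many "epochs" of steps_per_epoch batches cover one pass over the data, and whether the requested steps_per_epoch * epochs batches would exceed what the dataset can supply.

```python
import math


def plan_fit(total_batches: int, validate_every: int) -> dict:
    """Hypothetical helper: split one pass over `total_batches` batches
    into chunks of `validate_every` batches, one chunk per Keras 'epoch'."""
    if total_batches <= 0 or validate_every <= 0:
        raise ValueError("both arguments must be positive")
    epochs = math.ceil(total_batches / validate_every)
    needed = validate_every * epochs  # batches fit() will try to draw
    return {
        "steps_per_epoch": validate_every,
        "epochs": epochs,
        # True when the last chunk would run past the end of the data,
        # which is the case where .repeat() on the dataset is needed.
        "needs_repeat": needed > total_batches,
    }


plan = plan_fit(total_batches=1000, validate_every=100)
print(plan)  # {'steps_per_epoch': 100, 'epochs': 10, 'needs_repeat': False}
```

Assuming train_ds and val_ds are existing datasets, the first two values could then be passed as model.fit(train_ds, validation_data=val_ds, steps_per_epoch=plan["steps_per_epoch"], epochs=plan["epochs"]), so that validation runs at the end of each chunk.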


mpetteno commented on May 18, 2024

Yes, the point is that I don't want to use .repeat() because my dataset is already big enough. I want to run a single epoch (using the whole training dataset) but run validation every N training steps and save the model if the validation loss improves. For the last part I know I can use the ModelCheckpoint callback, but I'm not sure whether the built-in fit method can achieve the first requirement.

This is the code for the custom callback that handles this. The main issue is that some native callbacks (like the History callback) run at the end of each epoch and not at the batch step.

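import keras
# io_utils is an internal Keras module; this import path is an assumption for
# Keras 3 (older releases exposed it as keras.utils.io_utils).
from keras.src.utils import io_utils


class ValidationEveryNBatches(keras.callbacks.Callback):
    """Runs model.evaluate() on the validation data every `validation_freq` training batches."""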

    def __init__(self,
                 validation_data,
                 validation_batch_size,
                 validation_freq=1,
                 validation_steps=None,
                 validation_callbacks=None,
                 verbose='auto'):
        super().__init__()
        self._validation_freq = validation_freq
        self._validation_batch_size = validation_batch_size
        self._validation_steps = validation_steps
        self._validation_callbacks = validation_callbacks
        self._verbose = verbose
        self._val_x, self._val_y, self._val_sample_weight = keras.utils.unpack_x_y_sample_weight(validation_data)
        self._validations_count = 0

    def on_batch_end(self, batch, logs=None):
        if (batch + 1) % self._validation_freq == 0:
            self._validations_count += 1
            io_utils.print_msg('\n---------------------------------------------------------------------------------')
            io_utils.print_msg(f'Running validation after processing batch {batch + 1}. '
                               f'Total validations runs: {self._validations_count}')
            val_logs = self.model.evaluate(
                x=self._val_x,
                y=self._val_y,
                sample_weight=self._val_sample_weight,
                batch_size=self._validation_batch_size,
                steps=self._validation_steps,
                callbacks=self._validation_callbacks,
                return_dict=True,
                verbose=self._verbose
            )
            io_utils.print_msg('---------------------------------------------------------------------------------')
            val_logs = {"val_" + name: val for name, val in val_logs.items()}
            # Merge the val_* entries into the batch logs, even if the dict is empty.
            if logs is not None:
                logs.update(val_logs)
            self.model.reset_metrics()


fchollet commented on May 18, 2024

Fair enough, doing this in your own callback is a good solution. Lets you customize it to do whatever you want.


mpetteno commented on May 18, 2024

But how can I update the history so that it also keeps track of the training loss? At the moment the history is saved only at the end of each epoch, and in my case there is only one.


fchollet commented on May 18, 2024

Just make your callback create & update its own metrics/loss dict?
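
One way to read this suggestion, as a minimal sketch rather than a definitive implementation: a callback that keeps its own per-batch dict, independent of the built-in History callback. The class name BatchHistory and its history attribute are hypothetical. The recorded values are whatever Keras passes in logs after each training batch, which are typically running averages since the start of the epoch rather than per-batch values.

```python
import keras


class BatchHistory(keras.callbacks.Callback):
    """Hypothetical callback that stores the logs passed after each
    training batch in its own dict."""

    def __init__(self):
        super().__init__()
        self.history = {}  # metric name -> list of values, one per batch

    def on_train_batch_end(self, batch, logs=None):
        # logs may be None; treat that as "nothing to record".
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(float(value))
```

It would be passed like any other callback, for example model.fit(..., callbacks=[batch_history]), and read afterwards through batch_history.history["loss"].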


mpetteno commented on May 18, 2024

OK, I'm not sure I got it, but I'll try that, thanks. I was hoping there was a way to do this with the standard loop in the fit method. In the end, what I want is like running N epochs, but reading the training dataset from a certain index rather than from the beginning at each epoch.


mpetteno commented on May 18, 2024

The issue can be closed because it is not really an issue. Setting the seed and deterministic=True while loading the dataset with the tf.data API helped me understand how Keras works and achieve the desired result.
Keras supports validation every N batches natively through steps_per_epoch; just be sure to call .repeat() on the dataset if necessary, as stated by @fchollet above.

The custom callback is not necessary.

Thanks for your help in clarifying this.


