Comments (8)
Looking at the code in fit, it seems to me that evaluation runs only at the end of each epoch, and that even if I set steps_per_epoch, the next epoch will take elements from the beginning again, so I'm not using the whole dataset. Is that correct, or am I missing something?
Can you check that in practice with an example? I expect that you will continue drawing from the same dataset. It needs to be able to generate steps_per_epoch * epochs batches. You may need to call .repeat() on it to achieve that.
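To make the batch budget concrete, here is a minimal pure-Python sketch. It is only a model of the behaviour described above (one iterator shared across epochs), not Keras code; the names batches_needed and simulate_fit are hypothetical and chosen for illustration.

```python
import itertools


def batches_needed(steps_per_epoch, epochs):
    """Total number of batches the training loop will request."""
    return steps_per_epoch * epochs


def simulate_fit(num_batches_available, steps_per_epoch, epochs, repeat=False):
    """Model one iterator shared by all epochs; return the batch indices seen per epoch."""
    source = range(num_batches_available)
    # repeat=True plays the role of calling .repeat() on the dataset.
    iterator = itertools.cycle(source) if repeat else iter(source)
    seen = []
    for _ in range(epochs):
        # Each "epoch" pulls the next steps_per_epoch batches from the same iterator.
        seen.append(list(itertools.islice(iterator, steps_per_epoch)))
    return seen


# 10 batches, 5 steps per epoch, 2 epochs: the second epoch continues at batch 5.
print(simulate_fit(10, 5, 2))
# 3 epochs need 15 batches; without repeat the third epoch gets nothing.
print(simulate_fit(10, 5, 3))
# With repeat the iterator wraps around to the start.
print(simulate_fit(10, 5, 3, repeat=True))
```

If steps_per_epoch * epochs equals the number of batches in the dataset, every batch is consumed exactly once and no repeat is needed in this model.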
from keras.
Yes, the point is that I don't want to use .repeat() because my dataset is already big enough. I want to run a single epoch (using the whole training dataset) but run validation every N training steps and save the model if the validation loss improves. For the last part I know I can use the ModelCheckpoint callback, but I'm not sure whether I can use the built-in fit method to achieve the first requirement.
This is the code for the custom callback that handles this. The main issue is that some native callbacks (like the History callback) run only at the end of each epoch and not at each batch step.
import keras
# io_utils is a Keras-internal module; this import path assumes Keras 3.
from keras.src.utils import io_utils


class ValidationEveryNBatches(keras.callbacks.Callback):
    """Runs model.evaluate() every `validation_freq` training batches."""

    def __init__(self,
                 validation_data,
                 validation_batch_size,
                 validation_freq=1,
                 validation_steps=None,
                 validation_callbacks=None,
                 verbose='auto'):
        super(ValidationEveryNBatches, self).__init__()
        self._validation_freq = validation_freq
        self._validation_batch_size = validation_batch_size
        self._validation_steps = validation_steps
        self._validation_callbacks = validation_callbacks
        self._verbose = verbose
        # Split validation_data into inputs, targets and optional sample weights.
        self._val_x, self._val_y, self._val_sample_weight = keras.utils.unpack_x_y_sample_weight(validation_data)
        self._validations_count = 0

    def on_batch_end(self, batch, logs=None):
        # `batch` is zero-based, so validate after every N-th completed batch.
        if (batch + 1) % self._validation_freq == 0:
            self._validations_count += 1
            io_utils.print_msg('\n---------------------------------------------------------------------------------')
            io_utils.print_msg(f'Running validation after processing batch {batch + 1}. '
                               f'Total validations runs: {self._validations_count}')
            val_logs = self.model.evaluate(
                x=self._val_x,
                y=self._val_y,
                sample_weight=self._val_sample_weight,
                batch_size=self._validation_batch_size,
                steps=self._validation_steps,
                callbacks=self._validation_callbacks,
                return_dict=True,
                verbose=self._verbose
            )
            io_utils.print_msg('---------------------------------------------------------------------------------')
            # Prefix the metric names so they match Keras' "val_*" convention.
            val_logs = {"val_" + name: val for name, val in val_logs.items()}
            if logs:
                logs.update(val_logs)
            self.model.reset_metrics()
Fair enough, doing this in your own callback is a good solution. Lets you customize it to do whatever you want.
But how can I update the history so that it also tracks the training loss? At the moment the history is saved only at the end of each epoch, and in my case there is only one.
Just make your callback create & update its own metrics/loss dict?
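One way to read that suggestion is sketched below: a callback that keeps its own dict and appends whatever logs Keras passes at the end of each training batch. BatchHistory is a hypothetical name, the toy model and random data are assumptions for illustration, and the sketch assumes the default steps_per_execution of 1 so the hook fires once per batch.

```python
import numpy as np
import keras


class BatchHistory(keras.callbacks.Callback):
    """Hypothetical helper: stores the training logs after every batch."""

    def __init__(self):
        super().__init__()
        self.history = {}

    def on_train_batch_end(self, batch, logs=None):
        # Append each logged value (for example "loss") to its own list.
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(float(value))


# Toy data and model, assumed only for this example: 32 samples -> 4 batches of 8.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

batch_history = BatchHistory()
model.fit(x, y, batch_size=8, epochs=1, callbacks=[batch_history], verbose=0)

# One recorded loss value per training batch.
print(len(batch_history.history["loss"]))
```

Note that the loss Keras reports in these logs is, as far as I know, a running average since the start of the epoch rather than the loss of the single batch, so the recorded curve is smoother than true per-batch values.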
OK, I'm not sure I got it, but I'll try that, thanks. I was hoping there was a way to do this using the standard loop in the fit method. In the end, what I want is like running N epochs, but with each epoch reading the training dataset from a certain index rather than from the beginning.
The issue can be closed because it is not really an issue. Setting the seed and deterministic=True while loading the dataset with the tf.data API helped me understand how Keras works and achieve the desired result.
Keras supports validation every N batches natively through steps_per_epoch; just be sure to call .repeat() on the dataset if necessary, as stated by @fchollet above.
The custom callback is not necessary.
Thanks for your help in clarifying this.
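For anyone landing here later, this is a minimal sketch of the native approach, assuming a TensorFlow backend and a tf.data pipeline. The toy data, the model, the checkpoint path and the variable names are all assumptions for illustration. The idea is to set steps_per_epoch to N and epochs to total_batches // N, so that the short "epochs" together make up a single pass over the data, with validation and checkpointing after each one.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
import keras

# Toy data, assumed for illustration: 64 samples -> 8 training batches of 8.
batch_size = 8
x_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")
x_val = np.random.rand(16, 4).astype("float32")
y_val = np.random.rand(16, 1).astype("float32")

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)

total_batches = len(x_train) // batch_size  # batches in one full pass
validate_every = 2                          # N: training steps between validations
epochs = total_batches // validate_every    # short "epochs" that add up to one pass

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# Save the model only when val_loss improves.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "best.keras")
checkpoint = keras.callbacks.ModelCheckpoint(
    checkpoint_path, monitor="val_loss", save_best_only=True
)

history = model.fit(
    train_ds,
    steps_per_epoch=validate_every,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=[checkpoint],
    verbose=0,
)

# history now holds one loss / val_loss entry per validation run.
print(len(history.history["val_loss"]))
```

Here steps_per_epoch * epochs equals the number of batches in train_ds, so .repeat() is not needed; if the product were larger, the dataset would have to be repeated as noted above.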