Comments (6)
Hi @natbprice ,
Thanks for reporting. I have tested the code snippet and reproduced the reported behaviour. Gist attached for reference.
from keras.
Thanks for the report and the investigation. After looking into it in detail, I came to the conclusion that this works as expected.
I saw your proposed fix:
if isinstance(ds, _MapDataset) or isinstance(ds, _ParallelMapDataset):
    return ds._input_dataset._batch_size
But that's the batch size of the input dataset. The issue is that there is no constraint on what the function passed to map is allowed to do, therefore there is no guarantee that what comes out of map has the same batch size as what came in.
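To make this concrete: a mapped function can legally reshape its input, so that the output batch size differs from the batch size the input dataset was built with. A minimal sketch using only public tf.data APIs:

```python
import tensorflow as tf

# Batch 32 scalars into batches of 16...
ds = tf.data.Dataset.from_tensor_slices(tf.range(32))
ds = ds.batch(16, drop_remainder=True)

# ...then map a function that regroups each batch into 8 rows of 2.
ds = ds.map(lambda x: tf.reshape(x, (8, 2)))

# The element spec after map reports a leading dimension of 8,
# not the 16 that the input dataset was batched with.
batch_dim = ds.element_spec.shape[0]
```

So reading `_input_dataset._batch_size` here would report 16 while the dataset actually yields batches of 8.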
Now, why does this only happen when using multi-process distribution? That's because Keras is able to train with an unknown batch size in the normal case and only tries to determine the batch size if distribution is turned on.
What's the fix? Well, the standard pattern I've seen used is to batch last, after map, shuffle, etc.:
ds = tf.data.Dataset.from_tensor_slices((inputs, labels))
ds = ds.map(lambda x, y: (x, y))  # placeholder for any preprocessing
ds = ds.batch(16)
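With that pattern, the batch size is recoverable from the dataset's element spec, especially if you also pass drop_remainder=True so the leading dimension is static. A sketch of the idea (not the exact check Keras performs):

```python
import tensorflow as tf

inputs = tf.random.uniform((64, 4))
labels = tf.zeros((64,))

ds = tf.data.Dataset.from_tensor_slices((inputs, labels))
ds = ds.map(lambda x, y: (x, y))        # placeholder for any preprocessing
ds = ds.shuffle(64)
ds = ds.batch(16, drop_remainder=True)  # batch last; leading dim is static

x_spec, _ = ds.element_spec
batch_dim = x_spec.shape[0]  # the batch size survives, since nothing runs after batch
```

Without drop_remainder=True the leading dimension would be None (dynamic), which is fine for normal training but not when a concrete batch size must be known up front.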
Let me know if you have further questions.
@hertschuh thanks for investigating this. Based on your conclusion, it sounds like this issue should instead be resolved in keras-team/keras-nlp#1630? In that case, a preprocessor is being mapped over the data internally so there doesn't appear to be an easy workaround.
Sorry if I created extra work. I guess I should not have opened a related issue here.
Yes, I think the fix should be in keras-nlp. One should simply apply batch_size after the map and not in _convert_inputs_to_dataset. Do you want me to follow up in the keras-nlp bug?
@hertschuh if you don't mind following up in keras-nlp, that would be great! I think I understand the solution you are proposing, but I can't quite figure out the best way for the keras-nlp API to function. In particular, it seems like there are several combinations of (1) distribution strategy, (2) input types (e.g., tf.data.Dataset, NumPy arrays), and (3) batching (e.g., pre-batched dataset, explicit batch_size).
Currently, _convert_inputs_to_dataset will raise an error if you attempt to pass a tf.data.Dataset with an explicit batch_size argument. It also looks like there is error handling to prevent you from passing unbatched inputs, but the string matching on the error message may be outdated and not functioning.
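One alternative to string-matching an error message would be to inspect element_spec directly: scalar elements mean the dataset was never batched, and otherwise the leading dimension is the (possibly dynamic) batch size. This is a hypothetical helper, not existing keras-nlp code:

```python
import tensorflow as tf

def leading_dim(ds):
    """Return the static leading (batch) dimension of the first element
    component, or None if dynamic; raise for scalar (unbatched) elements."""
    spec = tf.nest.flatten(ds.element_spec)[0]
    if spec.shape.rank == 0:
        raise ValueError("Dataset elements are scalars; batch the dataset first.")
    return spec.shape[0]

scalars = tf.data.Dataset.from_tensor_slices(tf.range(8))  # unbatched scalars
batched = scalars.batch(4, drop_remainder=True)            # leading dim 4
```

Note this is only a heuristic: an unbatched dataset of vectors would still report a leading dimension, so it cannot fully replace an explicit contract on what inputs are allowed.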