
Comments (6)

SuryanarayanaY commented on September 22, 2024

Hi @natbprice,

Thanks for reporting. I have tested the code snippet and reproduced the reported behaviour. A gist is attached for reference.


hertschuh commented on September 22, 2024

@natbprice,

Thanks for the report and the investigation. After looking into it in detail, I came to the conclusion that this works as expected.

I saw your proposed fix:

  if isinstance(ds, _MapDataset) or isinstance(ds, _ParallelMapDataset):
    return ds._input_dataset._batch_size

But that's the batch size of the input dataset. The issue is that there is no constraint on what the function passed to map is allowed to do, so there is no guarantee that what comes out of map has the same batch size as what went in.
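
For illustration (a toy example, not from the original report), a map that silently changes the batch size:

import tensorflow as tf

ds = tf.data.Dataset.range(100).batch(16)
# This map keeps every other element, so each output batch has 8 elements,
# not 16. Reading ds._input_dataset._batch_size would still report 16.
ds = ds.map(lambda x: x[::2])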

Now, why does this only happen when using multi-process distribution? That's because Keras is able to train with an unknown batch size in the normal case and only tries to determine the batch size if distribution is turned on.
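
A rough way to see this (it relies on a private attribute, so details may differ across TF versions): the batch size is only discoverable on the batched dataset itself, and map hides it.

import tensorflow as tf

batched = tf.data.Dataset.range(100).batch(16)
print(int(batched._batch_size))        # 16: the batched dataset knows its batch size
mapped = batched.map(lambda x: x)
print(hasattr(mapped, "_batch_size"))  # False: after map it is no longer exposed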

What's the fix? Well, the standard pattern I've seen is to batch last, after map, shuffle, etc.:

# Build the pipeline first, then batch last.
ds = tf.data.Dataset.from_tensor_slices((inputs, labels))
ds = ds.map(lambda x, y: (x, y))
ds = ds.batch(16)
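
Batching last keeps the outermost dataset a batched one, so its batch size stays discoverable when multi-process distribution needs to look it up.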

Let me know if you have further questions.


google-ml-butler commented on September 22, 2024

Are you satisfied with the resolution of your issue?


natbprice commented on September 22, 2024

@hertschuh thanks for investigating this. Based on your conclusion, it sounds like this issue should instead be resolved in keras-team/keras-nlp#1630? In that case, a preprocessor is being mapped over the data internally, so there doesn't appear to be an easy workaround.

Sorry if I created extra work. I guess I should not have opened a related issue here.


hertschuh commented on September 22, 2024

> @hertschuh thanks for investigating this. Based on your conclusion, it sounds like this issue should instead be resolved in keras-team/keras-nlp#1630? In that case, a preprocessor is being mapped over the data internally, so there doesn't appear to be an easy workaround.
>
> Sorry if I created extra work. I guess I should not have opened a related issue here.

@natbprice,

Yes, I think the fix should be in keras-nlp. One should simply apply batch_size after the map and not in _convert_inputs_to_dataset. Do you want me to follow up in the keras-nlp bug?
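
Roughly, the idea would be something like this (a sketch only; raw_inputs, preprocessor, and batch_size are placeholders, not keras-nlp internals):

import tensorflow as tf

# Map the preprocessor over unbatched examples first, then batch,
# so the resulting dataset's batch size stays known.
ds = tf.data.Dataset.from_tensor_slices(raw_inputs)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(batch_size)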


natbprice commented on September 22, 2024

@hertschuh if you don't mind following up in keras-nlp, that would be great! I think I understand the solution you are proposing, but I can't quite figure out the best way for the keras-nlp API to function. In particular, there seem to be several combinations to handle: (1) distribution strategy, (2) input types (e.g., tf.data.Dataset, NumPy arrays), and (3) batching (e.g., pre-batched dataset, explicit batch_size).

Currently, _convert_inputs_to_dataset raises an error if you attempt to pass a tf.data.Dataset together with an explicit batch_size argument. It also looks like there is error handling to prevent you from passing unbatched inputs, but the string matching on the error message may be outdated and no longer functioning.

