
Comments (3)

Modexus commented on May 25, 2024

Here is a possible solution that only changes cast_storage of Image.
However, I'm not sure that the assumptions it makes about the ListArray always hold.

elif pa.types.is_list(storage.type):
    from .features import Array3DExtensionType

    def get_shapes(arr):
        shape = ()
        while isinstance(arr, pa.ListArray):
            len_curr = len(arr)
            arr = arr.flatten()
            len_new = len(arr)
            shape = shape + (len_new // len_curr,)
        return shape

    def get_dtypes(arr):
        # Walk down the nested list type of this slice to its innermost value type.
        dtype = arr.type
        while hasattr(dtype, "value_type"):
            dtype = dtype.value_type
        return dtype

    arrays = []
    for i, is_null in enumerate(storage.is_null()):
        if not is_null.as_py():
            storage_part = storage.take([i])
            shape = get_shapes(storage_part)
            dtype = get_dtypes(storage_part)

            extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype))
            array = pa.ExtensionArray.from_storage(extension_type, storage_part)
            arrays.append(array.to_numpy().squeeze(0))
        else:
            arrays.append(None)

    bytes_array = pa.array(
        [encode_np_array(arr)["bytes"] if arr is not None else None for arr in arrays],
        type=pa.binary(),
    )
    path_array = pa.array([None] * len(storage), type=pa.string())
    storage = pa.StructArray.from_arrays(
        [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()
    )

(Edited to handle nulls.)
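To make the ListArray assumption concrete: get_shapes divides the flattened length by the previous length at each level, which only gives the right answer if every sublist at that level has the same length. Below is a minimal sketch of the same logic on plain Python lists rather than pyarrow arrays. The names flatten_once and infer_shape are hypothetical and not part of datasets or pyarrow; flatten_once only approximates what ListArray.flatten() does for one level.

```python
def flatten_once(nested):
    """Concatenate the sublists one level down (stand-in for ListArray.flatten())."""
    return [item for sub in nested for item in sub]


def infer_shape(nested):
    """Mirror get_shapes: divide the flattened length by the current length."""
    shape = ()
    while nested and isinstance(nested[0], list):
        len_curr = len(nested)
        nested = flatten_once(nested)
        shape = shape + (len(nested) // len_curr,)
    return shape


# One "image" of 2 rows x 3 columns x 1 channel, wrapped in a batch of size 1,
# similar to the length-1 slice produced by storage.take([i]).
regular = [[[[1], [2], [3]], [[4], [5], [6]]]]

# Same layout, but the second row is one pixel short.
ragged = [[[[1], [2], [3]], [[4], [5]]]]

print(infer_shape(regular))  # (2, 3, 1)
print(infer_shape(ragged))   # (2, 2, 1), although there are 5 values, not 4
```

For the ragged input the integer division silently rounds down, so the inferred shape no longer matches the number of values. If ragged inputs can reach this branch, an explicit length check before building the extension type would probably be safer.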

Notably, this doesn't change how data is passed through or anything else; the change is confined to the Image class.
It seems quite fast:

Fri Apr  5 17:55:51 2024    restats

         63818 function calls (61995 primitive calls) in 0.812 seconds

   Ordered by: cumulative time
   List reduced from 1051 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     47/1    0.000    0.000    0.810    0.810 {built-in method builtins.exec}
      2/1    0.000    0.000    0.810    0.810 <string>:1(<module>)
      2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:594(wrapper)
      2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:551(wrapper)
      2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:2916(map)
        3    0.000    0.000    0.807    0.269 arrow_dataset.py:3277(_map_single)
        1    0.000    0.000    0.760    0.760 arrow_writer.py:589(finalize)
        1    0.000    0.000    0.760    0.760 arrow_writer.py:423(write_examples_on_file)
        1    0.000    0.000    0.759    0.759 arrow_writer.py:527(write_batch)
        1    0.001    0.001    0.754    0.754 arrow_writer.py:161(__arrow_array__)
      2/1    0.000    0.000    0.719    0.719 table.py:1800(wrapper)
        1    0.000    0.000    0.719    0.719 table.py:1950(cast_array_to_feature)
        1    0.006    0.006    0.718    0.718 image.py:209(cast_storage)
        1    0.000    0.000    0.451    0.451 image.py:361(encode_np_array)
        1    0.000    0.000    0.444    0.444 image.py:343(image_to_bytes)
        1    0.000    0.000    0.413    0.413 Image.py:2376(save)
        1    0.000    0.000    0.413    0.413 PngImagePlugin.py:1233(_save)
        1    0.000    0.000    0.413    0.413 ImageFile.py:517(_save)
        1    0.000    0.000    0.413    0.413 ImageFile.py:545(_encode_tile)
      397    0.409    0.001    0.409    0.001 {method 'encode' of 'ImagingEncoder' objects}
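For anyone who wants to produce a comparable listing for their own run, here is a minimal sketch using the standard-library cProfile and pstats modules. The work function is a hypothetical stand-in for the ds.map(...) call, and the profile above was read from a stats file named restats, whereas this sketch keeps everything in memory.

```python
import cProfile
import io
import pstats


def work():
    # Hypothetical stand-in for the ds.map(...) call being profiled.
    return sum(i * i for i in range(100_000))


profiler = cProfile.Profile()
profiler.enable()
work()
profiler.disable()

# Sort by cumulative time and limit the listing to 20 entries,
# matching the ordering and restriction shown in the output above.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("cumulative").print_stats(20)
report = buffer.getvalue()
print(report)
```

Sorting by cumulative time is what makes it easy to see that most of the 0.8 seconds above is spent in PNG encoding rather than in the shape handling.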


jdf-prog commented on May 25, 2024

I also encountered this problem and have been struggling with it for a long time...


Modexus commented on May 25, 2024

This actually applies to all arrays (NumPy arrays or tensors such as torch's), not only to images loaded from external files.

import numpy as np
import datasets

ds = datasets.Dataset.from_dict(
    {"image": [np.random.randint(0, 255, (2048, 2048, 3), dtype=np.uint8)]},
    features=datasets.Features({"image": datasets.Image(decode=True)}),
)
ds.set_format("numpy")

ds = ds.map(load_from_cache_file=False)
