deepmind / dsprites-dataset
Dataset to assess the disentanglement properties of unsupervised learning methods

License: Apache License 2.0

dSprites - Disentanglement testing Sprites dataset

This repository contains the dSprites dataset, used to assess the disentanglement properties of unsupervised learning methods.

If you use this dataset in your work, please cite it as follows:

Bibtex

@misc{dsprites17,
  author = {Loic Matthey and Irina Higgins and Demis Hassabis and Alexander Lerchner},
  title = {dSprites: Disentanglement testing Sprites dataset},
  howpublished = {https://github.com/deepmind/dsprites-dataset/},
  year = {2017},
}

Description

dsprite_gif

dSprites is a dataset of 2D shapes procedurally generated from 6 ground truth independent latent factors: color, shape, scale, rotation, and the x and y positions of a sprite.

All possible combinations of these latents are present exactly once, generating N = 737280 total images.

Latent factor values

  • Color: white
  • Shape: square, ellipse, heart
  • Scale: 6 values linearly spaced in [0.5, 1]
  • Orientation: 40 values in [0, 2 pi]
  • Position X: 32 values in [0, 1]
  • Position Y: 32 values in [0, 1]
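As a quick sanity check, the latent sizes listed above multiply out to the stated image count:

```python
# Sizes of each latent factor, in the order listed above:
# color, shape, scale, orientation, position X, position Y.
latent_sizes = [1, 3, 6, 40, 32, 32]

total = 1
for size in latent_sizes:
    total *= size

print(total)  # 737280
```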

We varied one latent at a time (starting from Position Y, then Position X, etc.), and sequentially stored the images in fixed order. Hence the order along the first dimension is fixed and allows you to map back to the values of the latents corresponding to that image.
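Given that fixed ordering, an image's flat index can be recovered from its integer latent classes with a mixed-radix dot product. The sketch below assumes the storage order described above (Position Y varying fastest); the helper name is illustrative:

```python
import numpy as np

# Latent sizes in storage order: color, shape, scale, orientation, posX, posY.
latents_sizes = np.array([1, 3, 6, 40, 32, 32])

# Per-latent multipliers: the last latent (posY) varies fastest, so its
# multiplier is 1, and each earlier latent's multiplier is the product of
# all the sizes that come after it.
latents_bases = np.concatenate((latents_sizes[::-1].cumprod()[::-1][1:],
                                np.array([1])))

def latents_to_index(latents):
    """Map integer latent classes to the flat index along the first axis."""
    return np.dot(latents, latents_bases).astype(int)

print(latents_to_index(np.array([0, 0, 0, 0, 0, 0])))  # 0
print(latents_to_index(np.array([0, 0, 0, 0, 0, 1])))  # 1
```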

We chose the latents values deliberately to have the smallest step changes while ensuring that all pixel outputs were different. No noise was added.

The data is an NPZ NumPy archive with the following fields:

  • imgs: (737280 x 64 x 64, uint8) Images in black and white.
  • latents_values: (737280 x 6, float64) Values of the latent factors.
  • latents_classes: (737280 x 6, int64) Integer index of the latent factor values. Useful as classification targets.
  • metadata: some additional information, including the possible latent values.

Alternatively, a HDF5 version is also available, containing the same data, packed as Groups and Datasets.
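One advantage of the HDF5 version is lazy slicing: h5py reads only the requested indices from disk, so you need not hold all 737280 images in memory. The sketch below uses a tiny stand-in file with the same `imgs` dataset name; substitute the real HDF5 file's path in practice:

```python
import h5py  # third-party: pip install h5py
import numpy as np

# Build a small stand-in file for illustration.
path = 'dsprites_demo.hdf5'
with h5py.File(path, 'w') as f:
    f.create_dataset('imgs', data=np.zeros((100, 64, 64), dtype=np.uint8))

with h5py.File(path, 'r') as f:
    imgs = f['imgs']   # h5py Dataset: nothing is loaded yet
    batch = imgs[:16]  # only these 16 images are read into memory

print(batch.shape)  # (16, 64, 64)
```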

Disentanglement metric

This dataset was created as a unit test of disentanglement properties of unsupervised models. It can be used to determine how well models recover the ground truth latents presented above.

You can find our proposed disentanglement metric, which assesses the disentanglement quality of a model (along with an example usage of this dataset), in:

Higgins, Irina, Loic Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew Botvinick, Shakir Mohamed, and Alexander Lerchner. "beta-VAE: Learning basic visual concepts with a constrained variational framework." In Proceedings of the International Conference on Learning Representations (ICLR). 2017.

Disclaimers

This is not an official Google product.

The images were generated using the LOVE framework, which is licensed under the zlib/libpng license:

LOVE is Copyright (c) 2006-2016 LOVE Development Team

This software is provided 'as-is', without any express or implied
warranty. In no event will the authors be held liable for any damages
arising from the use of this software.

Permission is granted to anyone to use this software for any purpose,
including commercial applications, and to alter it and redistribute it
freely, subject to the following restrictions:

1. The origin of this software must not be misrepresented; you must not
claim that you wrote the original software. If you use this software
in a product, an acknowledgment in the product documentation would be
appreciated but is not required.

2. Altered source versions must be plainly marked as such, and must not be
misrepresented as being the original software.

3. This notice may not be removed or altered from any source
distribution.

Issues

File Damage, dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz

I was running the example code and hit errors at:

# Load dataset
dataset_zip = np.load('dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')

print('Keys in the dataset:', dataset_zip.keys())
imgs = dataset_zip['imgs']
latents_values = dataset_zip['latents_values']
latents_classes = dataset_zip['latents_classes']
metadata = dataset_zip['metadata'][()]

print('Metadata: \n', metadata)

Keys in the dataset: KeysView(<numpy.lib.npyio.NpzFile object at 0x000002275653D550>)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-6aadbe2135e4> in <module>
      6 latents_values = dataset_zip['latents_values']
      7 latents_classes = dataset_zip['latents_classes']
----> 8 metadata = dataset_zip['metadata'][()]
      9 
     10 print('Metadata: \n', metadata)

~\Anaconda3\envs\PyTorch\lib\site-packages\numpy\lib\npyio.py in __getitem__(self, key)
    258             if magic == format.MAGIC_PREFIX:
    259                 bytes = self.zip.open(key)
--> 260                 return format.read_array(bytes,
    261                                          allow_pickle=self.allow_pickle,
    262                                          pickle_kwargs=self.pickle_kwargs)

~\Anaconda3\envs\PyTorch\lib\site-packages\numpy\lib\format.py in read_array(fp, allow_pickle, pickle_kwargs)
    737         # The array contained Python objects. We need to unpickle the data.
    738         if not allow_pickle:
--> 739             raise ValueError("Object arrays cannot be loaded when "
    740                              "allow_pickle=False")
    741         if pickle_kwargs is None:

ValueError: Object arrays cannot be loaded when allow_pickle=False

If I added allow_pickle=True to np.load, the error changed to:

Keys in the dataset: KeysView(<numpy.lib.npyio.NpzFile object at 0x000001B37C0D8C10>)
---------------------------------------------------------------------------
UnicodeDecodeError                        Traceback (most recent call last)
~\Anaconda3\envs\PyTorch\lib\site-packages\numpy\lib\format.py in read_array(fp, allow_pickle, pickle_kwargs)
    743         try:
--> 744             array = pickle.load(fp, **pickle_kwargs)
    745         except UnicodeError as err:

UnicodeDecodeError: 'ascii' codec can't decode byte 0xc9 in position 9: ordinal not in range(128)

During handling of the above exception, another exception occurred:

UnicodeError                              Traceback (most recent call last)
<ipython-input-4-c9307969a90f> in <module>
      6 latents_values = dataset_zip['latents_values']
      7 latents_classes = dataset_zip['latents_classes']
----> 8 metadata = dataset_zip['metadata'][()]
      9 
     10 print('Metadata: \n', metadata)

~\Anaconda3\envs\PyTorch\lib\site-packages\numpy\lib\npyio.py in __getitem__(self, key)
    258             if magic == format.MAGIC_PREFIX:
    259                 bytes = self.zip.open(key)
--> 260                 return format.read_array(bytes,
    261                                          allow_pickle=self.allow_pickle,
    262                                          pickle_kwargs=self.pickle_kwargs)

~\Anaconda3\envs\PyTorch\lib\site-packages\numpy\lib\format.py in read_array(fp, allow_pickle, pickle_kwargs)
    746             if sys.version_info[0] >= 3:
    747                 # Friendlier error message
--> 748                 raise UnicodeError("Unpickling a python object failed: %r\n"
    749                                    "You may need to pass the encoding= option "
    750                                    "to numpy.load" % (err,))

UnicodeError: Unpickling a python object failed: UnicodeDecodeError('ascii', b'\x00\x00\x00\x00\x00\x00\x00\x00\x1a\xc9\xc1\x1d*\x9f\xc4?\x1a\xc9\xc1\x1d*\x9f\xd4?\xa7\xad\xa2,\xbf\xee\xde?\x1a\xc9\xc1\x1d*\x9f\xe4?a;2\xa5\xf4\xc6\xe9?\xa7\xad\xa2,\xbf\xee\xee?\xf7\x8f\t\xdaD\x0b\xf2?\x1a\xc9\xc1\x1d*\x9f\xf4?>\x02za\x0f3\xf7?a;2\xa5\xf4\xc6\xf9?\x83t\xea\xe8\xd9Z\xfc?\xa7\xad\xa2,\xbf\xee\xfe?fs-8R\xc1\x00@\xf7\x8f\t\xdaD\x0b\x02@\x88\xac\xe5{7U\x03@\x1a\xc9\xc1\x1d*\x9f\x04@\xac\xe5\x9d\xbf\x1c\xe9\x05@>\x02za\x0f3\x07@\xcf\x1eV\x03\x02}\x08@a;2\xa5\xf4\xc6\t@\xf3W\x0eG\xe7\x10\x0b@\x83t\xea\xe8\xd9Z\x0c@\x15\x91\xc6\x8a\xcc\xa4\r@\xa7\xad\xa2,\xbf\xee\x0e@\x1de?\xe7X\x1c\x10@fs-8R\xc1\x10@\xae\x81\x1b\x89Kf\x11@\xf7\x8f\t\xdaD\x0b\x12@@\x9e\xf7*>\xb0\x12@\x88\xac\xe5{7U\x13@\xd1\xba\xd3\xcc0\xfa\x13@\x1a\xc9\xc1\x1d*\x9f\x14@c\xd7\xafn#D\x15@\xac\xe5\x9d\xbf\x1c\xe9\x15@\xf5\xf3\x8b\x10\x16\x8e\x16@>\x02za\x0f3\x17@\x87\x10h\xb2\x08\xd8\x17@\xcf\x1eV\x03\x02}\x18@\x18-DT\xfb!\x19@', 9, 10, 'ordinal not in range(128)')
You may need to pass the encoding= option to numpy.load
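The two errors above stem from the metadata field being a pickled Python object (pickled under Python 2, hence the encoding failure). The commonly reported fix is to pass both allow_pickle=True and encoding='latin1' to np.load. The sketch below reproduces the structure with a tiny stand-in archive so it runs on its own; apply the same call to the real dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz:

```python
import numpy as np

# Stand-in archive: 'metadata' is a pickled Python object, which is why
# the default np.load(..., allow_pickle=False) raises ValueError.
np.savez('demo.npz',
         imgs=np.zeros((2, 64, 64), dtype=np.uint8),
         metadata=np.array({'title': 'dSprites'}, dtype=object))

# allow_pickle=True permits unpickling; encoding='latin1' handles
# archives pickled under Python 2.
dataset_zip = np.load('demo.npz', allow_pickle=True, encoding='latin1')
metadata = dataset_zip['metadata'][()]  # unwrap the 0-d object array
print(metadata['title'])  # dSprites
```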

Duplicated factor values

It seems to me that angle number 0 and angle number 39 are the same? Did you use linspace(0, 360, 40) to get the angles? That's definitely wrong. Or did you do this deliberately to get around the rotation symmetry of squares and ellipses?
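The distinction the issue raises is whether the endpoint was included when sampling orientations. We do not know which call the authors used; the sketch below only shows that including both 0 and 2*pi makes the first and last angles coincide modulo a full turn:

```python
import numpy as np

# With the endpoint included, the first and last values are the same
# angle modulo 2*pi, so the corresponding renders can duplicate.
with_endpoint = np.linspace(0, 2 * np.pi, 40)
# With endpoint=False, all 40 angles are distinct modulo 2*pi.
without_endpoint = np.linspace(0, 2 * np.pi, 40, endpoint=False)

print(np.isclose(with_endpoint[-1] % (2 * np.pi), with_endpoint[0]))      # True
print(np.isclose(without_endpoint[-1] % (2 * np.pi), without_endpoint[0]))  # False
```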

hdf5 version does not contain all the fields

Although stated otherwise, the .hdf5 version of the dataset does not contain all the fields its .npz counterpart does; e.g. the metadata field is missing.
Working with the NumPy file isn't convenient either, as it's easy to run into memory issues: the whole dataset has to be loaded at once.
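One workaround for the memory issue, without relying on the HDF5 version, is to convert imgs to a plain .npy file once and memory-map it afterwards. The sketch below uses a small stand-in archive so it runs on its own; substitute the real .npz path:

```python
import numpy as np

# Stand-in for the real dSprites archive.
np.savez('demo_archive.npz', imgs=np.zeros((1000, 64, 64), dtype=np.uint8))

# One-off conversion: extract 'imgs' into a plain .npy file.
imgs_all = np.load('demo_archive.npz')['imgs']
np.save('demo_imgs.npy', imgs_all)

# Later runs memory-map the file: only accessed slices touch RAM.
imgs = np.load('demo_imgs.npy', mmap_mode='r')
batch = np.array(imgs[:16])  # materialise just these 16 images
print(batch.shape)  # (16, 64, 64)
```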

The NumPy archive images are not RGB

The README states that the data is a NPZ NumPy archive with the following fields:
imgs: (737280 x 3 x 64 x 64, uint8) Images in RGB.

However, the actual shape of the archive is (737280, 64, 64)

np.load('dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')['imgs'].shape
(737280, 64, 64)

Normal that shapes look distorted?

I'm wondering if I made a mistake when loading the dataset or if it is normal that the shapes look distorted?

Examples:

image
image
image

I'm loading these from the dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz file like this:

file_name = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
dataset_zip = np.load(os.path.join(data_dir, 'dSprites', file_name), encoding='latin1')
images = np.reshape(dataset_zip['imgs'], (num_samples, size, size, channels))
...

And then yielding these via the dataset API.

Regarding generative factors for the dataset

How were the images created? What was the original size of the shapes (in pixels) and how did you transform them? (what algorithms specifically)

Basically, what I am asking is: what are the true generative factors of the dataset?
