
SymJAX: symbolic CPU/GPU/TPU programming

This is a research project under active development, not an official product: expect bugs and sharp edges. Please help by trying it out and reporting bugs. See the reference docs.

Announcement: First major stable release expected by end of June.

What is SymJAX?

SymJAX is a symbolic programming version of JAX. It simplifies graph inputs, outputs and updates, and it adds functionality for general machine learning and deep learning applications. From a user's perspective, SymJAX resembles Theano, with fast graph optimization/compilation and broad hardware support, along with Lasagne-like deep learning functionality.
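The symbolic (declarative) model can be illustrated with a deliberately tiny pure-Python sketch. It is a hypothetical illustration, not SymJAX's implementation or API.

- Arithmetic on nodes only builds a graph; nothing is computed at that point.
- A hypothetical `function(outputs, updates)` helper turns the graph into a callable. Each call evaluates the outputs and then applies the variable updates.
- The sketch assumes, as in Theano, that outputs are computed from the pre-update state.

```python
class Node:
    """A symbolic expression: nothing is computed until evaluate() is called."""

    def __add__(self, other):
        return Op(lambda a, b: a + b, self, _lift(other))

    def __sub__(self, other):
        return Op(lambda a, b: a - b, self, _lift(other))

    def __mul__(self, other):
        return Op(lambda a, b: a * b, self, _lift(other))

    __rmul__ = __mul__  # multiplication is commutative, so 2 * x == x * 2


class Const(Node):
    def __init__(self, value):
        self.value = value

    def evaluate(self):
        return self.value


class Variable(Node):
    """Holds mutable state; only a compiled function's updates change it."""

    def __init__(self, value):
        self.value = value

    def evaluate(self):
        return self.value


class Op(Node):
    def __init__(self, fn, *args):
        self.fn, self.args = fn, args

    def evaluate(self):
        return self.fn(*(arg.evaluate() for arg in self.args))


def _lift(x):
    # wrap plain Python numbers so they can live in the graph
    return x if isinstance(x, Node) else Const(x)


def function(outputs, updates=None):
    """Hypothetical 'compile' step: returns a callable over the declared graph."""
    updates = updates or {}

    def run():
        result = outputs.evaluate()  # outputs see the pre-update state
        # compute every new value before assigning, so updates do not interfere
        new_values = {var: expr.evaluate() for var, expr in updates.items()}
        for var, value in new_values.items():
            var.value = value
        return result

    return run


# Declare the graph once, then execute it repeatedly.
counter = Variable(0.0)
double = counter * 2
step = function(outputs=double, updates={counter: counter + 1})
print([step() for _ in range(3)])  # [0.0, 2.0, 4.0]
```

The graph is declared once and executed many times. A real framework can therefore optimize and compile it before the first call; this toy only interprets it.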

Why SymJAX?

The number of libraries built on top of JAX/TensorFlow/PyTorch is large and growing by the day. Unlike most of them, SymJAX is an all-in-one library with diverse functionality, such as:

  • dozens of datasets with clear descriptions and one-line imports
  • advanced signal processing tools, such as multiple wavelet families (in the time and frequency domains), multiple time-frequency representations, apodization windows, ...
  • IO utilities to monitor, save and track specific statistics during graph execution through h5 files and NumPy, plus simple and explicit graph saving that makes saving and loading models painless
  • side utilities such as automatic dataset batching, data splitting, cross-validation, ...

and, most importantly, a SYMBOLIC/DECLARATIVE programming environment that allows CONCISE/EXPLICIT/OPTIMIZED computations across devices.
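The batching and cross-validation utilities listed above are easiest to picture with a small example. The following is a minimal pure-Python sketch of the general technique. `batchify` and `k_fold` are hypothetical names, not SymJAX functions, so check the reference docs for the library's actual data utilities.

```python
import random


def batchify(data, batch_size, shuffle=False, seed=0):
    """Yield successive mini-batches of `data`; the last batch may be smaller."""
    indices = list(range(len(data)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # seeded for reproducibility
    for start in range(0, len(indices), batch_size):
        yield [data[i] for i in indices[start:start + batch_size]]


def k_fold(n, k):
    """Return k (train_indices, valid_indices) pairs over n items.

    Fold i validates on items i, i + k, i + 2k, ... and trains on the rest,
    so each item is used for validation exactly once.
    """
    splits = []
    for i in range(k):
        valid = list(range(i, n, k))
        valid_set = set(valid)
        train = [j for j in range(n) if j not in valid_set]
        splits.append((train, valid))
    return splits


print(list(batchify(list(range(7)), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
print(k_fold(6, 3)[0])  # ([1, 2, 4, 5], [0, 3])
```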

For imperative programming with JAX, see for example Flax; more generally, see TensorFlow or PyTorch.

Examples

import symjax as sj
import symjax.tensor as T

# create our variable to be optimized
mu = T.Variable(T.random.normal((), seed=1))

# create our cost
cost = T.exp(-(mu-1)**2)

# get the gradient, notice that it is itself a tensor that can then
# be manipulated as well
g = sj.gradients(cost, mu)
print(g)

# (Tensor: shape=(), dtype=float32)

# create the compiled function that will compute the cost and apply
# the update to the variable
f = sj.function(outputs=cost, updates={mu:mu-0.2*g})

for i in range(10):
    print(f())

# 0.008471076
# 0.008201109
# 0.007946267
# ...
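The example's loop can be mirrored in plain Python to show what each call to `f()` computes. The derivative of the cost is written out by hand:

d/dmu exp(-(mu-1)^2) = -2 (mu-1) exp(-(mu-1)^2)

This is a sketch under two assumptions:

- The starting value `mu = 3.0` is arbitrary, since the original draws it at random, so the printed numbers will differ from the output above.
- As in Theano, each call returns the cost at the pre-update value and then applies `mu - 0.2 * g`.

```python
import math


def cost(mu):
    return math.exp(-(mu - 1.0) ** 2)


def grad(mu):
    # analytic derivative: d/dmu exp(-(mu-1)^2) = -2 (mu-1) exp(-(mu-1)^2)
    return -2.0 * (mu - 1.0) * cost(mu)


def descend(mu, lr=0.2, steps=10):
    """Plain gradient descent mirroring the updates={mu: mu - 0.2 * g} rule."""
    history = []
    for _ in range(steps):
        history.append(cost(mu))  # the output uses the pre-update value
        mu = mu - lr * grad(mu)   # then the update is applied
    return history, mu


history, mu_final = descend(3.0)  # hypothetical starting point
print(history[:3])
```

When mu starts away from 1, each step pushes it further away. The cost therefore decreases monotonically, as in the printed output above.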

Installation

For GPU support (optional), first install the required GPU drivers. Then install JAX as follows (see the JAX installation guide):

# install jaxlib
PYTHON_VERSION=cp37  # alternatives: cp35, cp36, cp37, cp38
CUDA_VERSION=cuda92  # alternatives: cuda92, cuda100, cuda101, cuda102
PLATFORM=linux_x86_64  # alternatives: linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.39-$PYTHON_VERSION-none-$PLATFORM.whl

pip install --upgrade jax  # install jax
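The `PYTHON_VERSION` tag must match the running interpreter. The sketch below derives that tag in Python and assembles the same URL as the shell command above. `python_tag` and `jaxlib_wheel_url` are hypothetical helpers that simply mirror the README's URL template. They do not check whether a wheel actually exists for a given combination.

```python
import sys

BASE_URL = "https://storage.googleapis.com/jax-releases"


def python_tag(version_info=sys.version_info):
    """CPython wheel tag for an interpreter version, e.g. (3, 7) -> 'cp37'."""
    return f"cp{version_info[0]}{version_info[1]}"


def jaxlib_wheel_url(cuda_version, python_version=None,
                     platform="linux_x86_64", jaxlib_version="0.1.39"):
    """Rebuild $BASE_URL/$CUDA_VERSION/jaxlib-...-$PYTHON_VERSION-none-$PLATFORM.whl."""
    python_version = python_version or python_tag()
    return (f"{BASE_URL}/{cuda_version}/"
            f"jaxlib-{jaxlib_version}-{python_version}-none-{platform}.whl")


print(python_tag())
print(jaxlib_wheel_url("cuda101"))
```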

Then simply install SymJAX as follows:

pip install symjax

Once this is done, set the following environment variable to use the datasets:

export DATASET_PATH=/path/to/default/location/
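Inside Python, this variable is read through `os.environ`. The minimal sketch below resolves a download directory: an explicit path wins, otherwise `DATASET_PATH` is used. `resolve_dataset_path` is a hypothetical helper illustrating this fallback, not SymJAX's internal code.

```python
import os


def resolve_dataset_path(path=None):
    """Hypothetical helper: explicit path first, then $DATASET_PATH."""
    if path is not None:
        return path
    env_path = os.environ.get("DATASET_PATH")
    if env_path:
        return env_path
    raise ValueError("no path given and DATASET_PATH is not set")


print(resolve_dataset_path("/data/mnist"))  # /data/mnist
```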

This path is used as the default download location for the datasets when no explicit path is given. In addition, the following variables are commonly set to link against the CUDA library and to disable GPU memory preallocation. The example below is for CUDA 10.1; adjust it for your version:

export CUDA_DIR="/usr/local/cuda-10.1"
export LD_LIBRARY_PATH=$CUDA_DIR/lib64:$LD_LIBRARY_PATH
export LIBRARY_PATH=$CUDA_DIR/lib64:$LIBRARY_PATH
export XLA_PYTHON_CLIENT_PREALLOCATE='false'
export XLA_FLAGS="--xla_gpu_cuda_data_dir=$CUDA_DIR"
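The two XLA variables can also be set from Python, as long as this happens before JAX is imported (or at least before it initializes its backend). `LD_LIBRARY_PATH` is different: the dynamic loader reads it at process start, so it should stay in the shell. `configure_xla` is a hypothetical helper, and like the shell line above it overwrites any existing `XLA_FLAGS`.

```python
import os


def configure_xla(cuda_dir, preallocate=False):
    """Set XLA options for this process; call before importing jax."""
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    # overwrites any existing XLA_FLAGS, matching the export line above
    os.environ["XLA_FLAGS"] = f"--xla_gpu_cuda_data_dir={cuda_dir}"


configure_xla("/usr/local/cuda-10.1")
# import jax  # import only after the variables are set
```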

Contributors

randallbalestriero, leonard-seydoux, brandonwillard, joaogui1, koldh, dependabot[bot]