
yuvalailer / nnplot


:tv: A Python library for pruning and visualizing Keras Neural Networks' structure and weights

License: MIT License

Python 100.00%
neural-network neural-networks neural-network-architectures neural-network-python visualization visualization-tools weights-visualization pruning keras-visualization keras-neural-networks

nnplot's Introduction

nnplot



nnplot is a Python library for producing informative visualizations of neural networks.

It can display the network's structure, the dominance of the link weights in each layer, and their polarity (positive/negative).
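To make "dominance" and "polarity" concrete, here is a minimal sketch, assuming a Dense layer's kernel has been converted to a nested list shaped (inputs, units), which is the layout of the kernel that Keras' `get_weights()[0]` returns for a Dense layer (e.g. after `.tolist()`). The `edge_stats` helper is hypothetical and is not part of nnplot; it scores each link by its magnitude relative to the strongest link in the same layer and by its sign.

```python
# ASSUMPTION: edge_stats() is a hypothetical illustration, not nnplot's code.
# kernel is a nested list shaped (inputs, units), like a Dense layer's kernel.

def edge_stats(kernel):
    """Return {(src, dst): (dominance, polarity)} for one layer.

    dominance: |w| divided by the largest |w| in the layer (0.0 to 1.0)
    polarity:  +1 for positive weights, -1 for negative ones, 0 for zero
    """
    largest = max(abs(w) for row in kernel for w in row) or 1.0  # avoid /0
    stats = {}
    for src, row in enumerate(kernel):
        for dst, w in enumerate(row):
            polarity = (w > 0) - (w < 0)  # bool arithmetic gives -1, 0 or 1
            stats[(src, dst)] = (abs(w) / largest, polarity)
    return stats


# Usage: a layer with 2 inputs and 2 units.
layer = [[0.5, -2.0],
         [1.0,  0.0]]
print(edge_stats(layer)[(0, 1)])  # (1.0, -1)
```

Normalizing per layer, rather than across the whole network, keeps a layer with small weights from rendering as uniformly thin edges.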

It also provides functions for pruning the network so that only the n “most important” nodes of each layer are displayed.
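The README does not state how "most important" is measured. The sketch below assumes one common heuristic, ranking each node by the sum of the absolute weights feeding into it; `top_n_nodes` is a hypothetical helper, and nnplot's own criterion may well differ.

```python
# ASSUMPTION: importance = sum of absolute incoming weights. This is an
# illustrative heuristic, not necessarily the criterion nnplot's prune uses.

def top_n_nodes(kernel, n):
    """Return indexes of the n units with the largest incoming |weight| sum.

    kernel is a nested list shaped (inputs, units). Indexes are returned in
    ascending order so the surviving nodes keep their original layout.
    """
    units = len(kernel[0])
    importance = [sum(abs(row[j]) for row in kernel) for j in range(units)]
    ranked = sorted(range(units), key=lambda j: importance[j], reverse=True)
    return sorted(ranked[:n])


kernel = [[0.1, -3.0, 0.2],
          [0.4,  1.0, -2.5]]
print(top_n_nodes(kernel, 2))  # [1, 2]
```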

It's simple:

Install:

pip install nnplot graphviz

And in your code:

from nnplot.functions import plot_net

plot_net(trained_model)

and your network will be plotted!

See more:

How to install it?

nnplot is available via pip:

pip install nnplot graphviz

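*Alternatively, you can download the nnplot subfolder and place it in the same directory as your code.

Note that the graphviz Python package only wraps Graphviz: rendering also requires the Graphviz system software (the dot executable) to be installed and on your PATH.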

How to use it?

Once you have a built and trained Keras model, you can simply:

Plot the model:

from nnplot.functions import plot_net

plot_net(trained_model)

or prune the model:

from nnplot.functions import prune

[new_model, new_input_list] = prune(model, max_limit=5)

To get the most out of these functions, use their optional arguments (examples follow):

plot:

from nnplot.functions import plot_net

plot_net(model,
         input_list    = ['1: Input A', '2: Input B', '3: Input C', 'etc..'],
         view          = True,
         filename      = "my_network.gv",
         plot_title    = "My Neural Network Title",
         out_title     = "This is my output title: \n 1. output A \n 2. output B",
         color_edges   = "rb",
         print_weights = False,
         size_limit    = 10)

Arguments:

model: A Keras model instance.

view: whether to display the plot on screen after it is generated.

filename: path and file name under which to save the visualization, both as a PDF and as a .gv (Graphviz) file.
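A .gv file is plain Graphviz DOT source. As an illustration only, the hypothetical `to_dot` helper below emits DOT text for a chain of Dense kernels (each a nested list shaped (inputs, units)), mapping weight magnitude to edge thickness (`penwidth`) and sign to a red/black color. The files nnplot actually writes will look different.

```python
# ASSUMPTION: to_dot() is a hypothetical, simplified exporter for
# illustration; it is not nnplot's implementation.

def to_dot(kernels, max_width=5.0):
    """Build DOT source for a chain of Dense kernels, each shaped (inputs, units)."""
    largest = max(abs(w) for k in kernels for row in k for w in row) or 1.0
    lines = ["digraph net {", "  rankdir=LR;"]  # lay layers out left to right
    for layer, kernel in enumerate(kernels):
        for src, row in enumerate(kernel):
            for dst, w in enumerate(row):
                width = max_width * abs(w) / largest  # thicker = more dominant
                color = "red" if w > 0 else "black"   # positive red, else black
                lines.append(
                    f'  "L{layer}_{src}" -> "L{layer + 1}_{dst}" '
                    f'[penwidth={width:.2f}, color={color}];'
                )
    lines.append("}")
    return "\n".join(lines)


# 1 input -> 2 hidden -> 1 output
dot = to_dot([[[0.5, -1.0]], [[2.0], [-0.25]]])
print(dot)
```

With the graphviz Python package installed, `graphviz.Source(dot).render("my_network.gv", format="pdf")` should save this DOT source and render a PDF from it.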

plot_title: A title for the graph.

color_edges: how to visualize the weights of the edges as colors.

options:

  • "rb" - Red / Black: red for positive edges and black for negative ones.
  • "mc" - Multi Colored: all edges that converge into the same node share the same (unique) color.
  • "none" - all edges painted black (but thickness visualization remains).
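The three modes can be sketched as a small dispatch function. `edge_color` and `PALETTE` are hypothetical, and the palette is an assumption: nnplot's actual "mc" colors are not documented here, and this sketch reuses colors once a layer has more target nodes than palette entries.

```python
# ASSUMPTION: edge_color() and PALETTE are illustrative only; nnplot's real
# color choices are not documented in this README.

PALETTE = ["blue", "green", "orange", "purple", "brown", "cyan"]


def edge_color(mode, weight, dst):
    """Pick an edge color; dst is the index of the node the edge converges into."""
    if mode == "rb":
        return "red" if weight > 0 else "black"  # color encodes polarity
    if mode == "mc":
        # one color per target node; cycles if a layer exceeds len(PALETTE)
        return PALETTE[dst % len(PALETTE)]
    if mode == "none":
        return "black"  # only the edge thickness carries information
    raise ValueError(f"unknown color_edges mode: {mode!r}")


print(edge_color("rb", -0.3, 0))  # black
print(edge_color("mc", 0.7, 2))   # orange
```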

print_weights: whether to print the weights of the edges to the screen.

size_limit: maximum number of nodes displayed in each layer (simply the first n nodes; use prune for a more sophisticated node selection).
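To show why prune is the better choice when the strongest nodes are not at the front, here is a minimal sketch of "first n nodes" truncation on a chain of Dense kernels shaped (inputs, units). `limit_layers` is hypothetical and applies the limit to every layer, inputs and outputs included; a node beyond position n is dropped no matter how strong its weights are.

```python
# ASSUMPTION: limit_layers() is a hypothetical helper that mimics a purely
# positional size limit; it is not nnplot's code.

def limit_layers(kernels, size_limit):
    """Keep only the first size_limit rows (sources) and columns (targets)."""
    return [[row[:size_limit] for row in kernel[:size_limit]]
            for kernel in kernels]


kernels = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],  # 3 inputs -> 3 units
           [[1], [2], [3]]]                     # 3 units  -> 1 output
print(limit_layers(kernels, 2))  # [[[1, 2], [4, 5]], [[1], [2]]]
```

Because the same positional cut is applied to the columns of one kernel and the rows of the next, consecutive layers stay consistent without any extra bookkeeping.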

prune:

from nnplot.functions import prune

[new_network, new_input_list] = prune(model,
                                      max_limit,
                                      input_list = ['1: Input A', '2: Input B', 'etc..'],
                                      verbose    = True)

Arguments:

model: A Keras model instance.

max_limit: maximum number of nodes kept in each layer. (How are the nodes picked in this prune?)

input_list: list of input names, so that the returned input indexes are reported with their original names.

verbose: whether to print information about the process along the way.

Outputs:

network: the new pruned network.

input_indexes: the indexes of the chosen inputs.
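Pruning has to keep consecutive layers consistent: removing a node from one layer deletes a column of one kernel and a row of the next. The sketch below assumes the kept nodes of every layer are already known (it does not reproduce prune's selection) and shows how the kernels can be sliced and how the kept input indexes map back to names from input_list. `select_nodes` is a hypothetical helper, not part of nnplot.

```python
# ASSUMPTIONS: select_nodes() is hypothetical; the per-layer selections are
# given as input rather than computed the way nnplot's prune computes them.

def select_nodes(kernels, keep, input_list):
    """Slice kernels to the kept nodes and name the surviving inputs.

    kernels: nested lists, kernels[i] shaped (nodes in layer i, nodes in layer i+1)
    keep:    keep[i] = sorted indexes retained in layer i (len(kernels) + 1 lists)
    """
    pruned = []
    for layer, kernel in enumerate(kernels):
        # rows follow the kept nodes of this layer, columns those of the next
        pruned.append([[kernel[r][c] for c in keep[layer + 1]]
                       for r in keep[layer]])
    kept_names = [input_list[i] for i in keep[0]]  # map indexes back to names
    return pruned, kept_names


kernels = [[[1, 2], [3, 4], [5, 6]],  # 3 inputs -> 2 hidden
           [[7], [8]]]                # 2 hidden -> 1 output
names = ["1: Input A", "2: Input B", "3: Input C"]
pruned, kept = select_nodes(kernels, [[0, 2], [1], [0]], names)
print(pruned)  # [[[2], [6]], [[8]]]
print(kept)    # ['1: Input A', '3: Input C']
```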

nnplot's People

Contributors: yuval-ai

