michalwols / yann

Yet Another Neural Network Library 🤔

Home Page: https://michalwols.github.io/yann/

License: MIT License

Languages: Python 99.80%, Shell 0.13%, Dockerfile 0.06%

Topics: pytorch, deep-learning, neural-network, python, nn, torch

yann (Yet Another Neural Network Library)

Yann is an extended version of torch.nn, adding a ton of sugar to make training models as fast and easy as possible.

Getting Started

Install

pip install yann

Train LeNet on MNIST

import torch
from torch import nn
from torchvision import transforms

import yann
from yann.train import Trainer
from yann.modules import Stack, Flatten, Infer
from yann.params import HyperParams, Choice, Range


class Params(HyperParams):
  dataset = 'MNIST'
  batch_size = 32
  epochs = 10
  optimizer: Choice(('SGD', 'Adam')) = 'SGD'
  learning_rate: Range(.01, .0001) = .01
  momentum = 0

  seed = 1

# parse command line arguments
params = Params.from_command()

# set random, numpy and pytorch seeds in one call
yann.seed(params.seed)

lenet = Stack(
  Infer(nn.Conv2d, 10, kernel_size=5),
  nn.MaxPool2d(2),
  nn.ReLU(inplace=True),
  Infer(nn.Conv2d, 20, kernel_size=5),
  nn.MaxPool2d(2),
  nn.ReLU(inplace=True),
  Flatten(),
  Infer(nn.Linear, 50),
  nn.ReLU(inplace=True),
  Infer(nn.Linear, 10),
  activation=nn.LogSoftmax(dim=1)
)

# run a forward pass to infer input shapes using `Infer` modules
lenet(torch.rand(1, 1, 28, 28))

# use the registry to resolve optimizer name to an optimizer class
optimizer = yann.resolve.optimizer(
  params.optimizer,
  yann.trainable(lenet.parameters()),
  momentum=params.momentum,
  lr=params.learning_rate
)

train = Trainer(
  model=lenet,
  optimizer=optimizer,
  dataset=params.dataset,
  batch_size=params.batch_size,
  transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
  ]),
  loss='nll_loss',
  metrics=('accuracy',)
)

train(params.epochs)

# save checkpoint
train.checkpoint()

# plot the loss curve
train.history.plot()
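
The `yann.resolve.optimizer` call in the script above maps a name such as 'SGD' to an optimizer class. The general pattern can be sketched without yann or PyTorch. This is a minimal illustration of a name registry, not yann's actual implementation; `Registry`, `register`, `resolve` and the toy `SGD` class are hypothetical names introduced for the example.

```python
class Registry:
  """Maps string names to classes or functions (hypothetical sketch)."""

  def __init__(self):
    self._entries = {}

  def register(self, obj, name=None):
    # Index by the given name, or fall back to the object's own __name__.
    self._entries[name or obj.__name__] = obj
    return obj  # returning obj lets register() double as a decorator

  def resolve(self, name, *args, **kwargs):
    # Anything that is not a string is assumed to be resolved already.
    if not isinstance(name, str):
      return name
    if name not in self._entries:
      raise KeyError(f'unknown name {name!r}; known: {sorted(self._entries)}')
    entry = self._entries[name]
    # With arguments, instantiate the entry; without, return it unchanged.
    return entry(*args, **kwargs) if (args or kwargs) else entry


optimizers = Registry()


@optimizers.register
class SGD:  # toy stand-in, not torch.optim.SGD
  def __init__(self, params, lr=0.01, momentum=0):
    self.params = list(params)
    self.lr = lr
    self.momentum = momentum


opt = optimizers.resolve('SGD', [1.0, 2.0], lr=0.1)
```

Resolving by name is what lets a string from the command line (`-o Adam`) select the optimizer without an if/else chain in the training script.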

View the generated CLI help:

python train.py -h
usage: train_mnist.py [-h] [-o {SGD,Adam}] [-lr LEARNING_RATE] [-d DATASET]
                      [-bs BATCH_SIZE] [-e EPOCHS] [-m MOMENTUM] [-s SEED]

optional arguments:
  -h, --help            show this help message and exit
  -o {SGD,Adam}, --optimizer {SGD,Adam}
                        optimizer (default: SGD)
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        learning_rate (default: 0.01)
  -d DATASET, --dataset DATASET
                        dataset (default: MNIST)
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        batch_size (default: 32)
  -e EPOCHS, --epochs EPOCHS
                        epochs (default: 10)
  -m MOMENTUM, --momentum MOMENTUM
                        momentum (default: 0)
  -s SEED, --seed SEED  seed (default: 1)
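
The help text above is generated from the attributes of the `Params` class. One way such a CLI could be derived from class attributes is sketched below using only `argparse`. This is an assumption about the general approach, not yann's code: `MiniParams`, `build_parser` and `from_args` are hypothetical names, and the short flags (`-bs`) and the `Choice`/`Range` annotations are left out for brevity.

```python
import argparse


class MiniParams:
  dataset = 'MNIST'
  batch_size = 32
  learning_rate = .01


def build_parser(params_cls):
  """Create one --flag per public class attribute (hypothetical helper)."""
  parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
  )
  for name, default in vars(params_cls).items():
    if name.startswith('_') or callable(default):
      continue
    # The type of the default value decides how the CLI string is parsed.
    # Note: this naive rule does not handle bool defaults sensibly.
    parser.add_argument(
      f'--{name}', type=type(default), default=default, help=name
    )
  return parser


def from_args(params_cls, argv=None):
  """Parse argv (or sys.argv when None) into an instance of params_cls."""
  namespace = build_parser(params_cls).parse_args(argv)
  params = params_cls()
  for name, value in vars(namespace).items():
    setattr(params, name, value)
  return params


params = from_args(MiniParams, ['--batch_size', '16'])
```

Values passed on the command line override the class defaults on the returned instance only; the class attributes themselves are left unchanged.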

Then start a training run:

python train.py -bs=16

This should print the following to stdout:

Params(
  optimizer=SGD,
  learning_rate=0.01,
  dataset=MNIST,
  batch_size=16,
  epochs=10,
  momentum=0,
  seed=1
)
Starting training

name: MNIST-Stack
root: train-runs/MNIST-Stack/19-09-25T18:02:52
batch_size: 16
device: cpu

MODEL
=====

Stack(
  (infer0): Infer(
    (module): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  )
  (max_pool2d0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (re_lu0): ReLU(inplace=True)
  (infer1): Infer(
    (module): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  )
  (max_pool2d1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (re_lu1): ReLU(inplace=True)
  (flatten0): Flatten()
  (infer2): Infer(
    (module): Linear(in_features=320, out_features=50, bias=True)
  )
  (re_lu2): ReLU(inplace=True)
  (infer3): Infer(
    (module): Linear(in_features=50, out_features=10, bias=True)
  )
  (activation): LogSoftmax()
)


DATASET
=======

TransformDataset(
Dataset: Dataset MNIST
    Number of datapoints: 60000
    Root location: /Users/michal/.torch/datasets/MNIST
    Split: Train
Transforms: (Compose(
    ToTensor()
    Normalize(mean=(0.1307,), std=(0.3081,))
),)
)


LOADER
======

<torch.utils.data.dataloader.DataLoader object at 0x1a45cc8940>

LOSS
====

<function nll_loss at 0x120b700d0>


OPTIMIZER
=========

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0
    nesterov: False
    weight_decay: 0.0001
)

SCHEDULER
=========

None


PROGRESS
========
epochs: 0
steps: 0
samples: 0


Starting epoch 0

OPTIMIZER
=========

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0
    nesterov: False
    weight_decay: 0.0001
)


PROGRESS
========
epochs: 0
steps: 0
samples: 0


Batch inputs shape: (16, 1, 28, 28)
Batch targets shape: (16,)
Batch outputs shape: (16, 10)

batch:        0	accuracy: 0.1875	loss: 2.3783
batch:      128	accuracy: 0.6250	loss: 2.0528
batch:      256	accuracy: 0.6875	loss: 0.6222

yann's Issues

CLI improvements

Train

Allow queuing experiments with ranked priority, and potentially handle running them on a remote machine.

yann train -m resnet50 -d ImageNet -bs 32 --distributed --priority=2

Scaffold Project

use cookiecutters to define common project types

yann scaffold project-name  --type/template image-recognition
  .github/actions
  data/
    raw/
    processed/
  train-runs/
  notebooks/
  docs/
  tests/
  {{project_name}}/
    models/
    datasets/
    cli.py
    train.py
    evaluate.py
    serve.py
  requirements.txt
  conda.yml
  setup.py
  dockerfile
  run
  README.md


  run
    prepare-data()
    test()
    train()
    evaluate()
    install-dependencies()
    save-dependencies()
    demo()
    deploy()
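
The issue proposes cookiecutter templates for this. As a rough illustration of what the scaffold step produces, the sketch below creates a subset of the layout above with plain `pathlib` instead of cookiecutter. The `scaffold` function and the two lists are hypothetical and are not part of yann.

```python
import tempfile
from pathlib import Path

# Subset of the tree above; `{name}` is replaced by the package name.
DIRECTORIES = [
  'data/raw', 'data/processed', 'train-runs', 'notebooks', 'docs', 'tests',
  '{name}/models', '{name}/datasets',
]
FILES = [
  '{name}/cli.py', '{name}/train.py', '{name}/evaluate.py', '{name}/serve.py',
  'requirements.txt', 'setup.py', 'README.md',
]


def scaffold(parent, name):
  """Create an empty project skeleton under parent/name and return its path."""
  root = Path(parent) / name
  for directory in DIRECTORIES:
    (root / directory.format(name=name)).mkdir(parents=True, exist_ok=True)
  for file in FILES:
    (root / file.format(name=name)).touch()
  return root


# Build the skeleton in a throwaway directory and list the Python files.
with tempfile.TemporaryDirectory() as tmp:
  project = scaffold(tmp, 'image_recognition')
  created = sorted(
    p.relative_to(project).as_posix() for p in project.rglob('*.py')
  )
```

A real template would also fill the files with starter code, which is the part cookiecutter handles.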

Serve Models

Support runtimes like ONNX Runtime, TensorRT, and JIT-traced models.

yann serve ./checkpoint-path

yann serve model-name

Compare Train Runs

yann compare ./train-runs/MNIST-*

Evaluate model on given data

run inference

yann evaluate ./model.th cifar10 ./output.parquet

evaluate model

yann validate ./model.th cifar10 ./output.parquet

Profile script

yann profile/benchmark

PyTorch profiler with Chrome flamegraph

https://github.com/netdata/netdata
https://github.com/nicolargo/glances

  • gpu utilization
  • cpu
  • memory
  • network
  • disk io

Export models for Inference

yann export

JIT TorchScript, quantization, ONNX, etc.

Convert data

yann convert data.csv data.parquet

Train run queuing

Use dask distributed or huey as a simple queue for queueing experiments.
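
Before reaching for Dask or Huey, the ranked-priority behaviour can be sketched in-process with the standard library. The example below is only an illustration: `ExperimentQueue` is a hypothetical name, the convention that a lower number runs first is an assumption (the issue does not say which direction ranks higher), and the `yann train` commands are treated as plain data since that CLI is only proposed above.

```python
import heapq
import itertools


class ExperimentQueue:
  """Hypothetical in-process queue; a lower priority number runs first."""

  def __init__(self):
    self._heap = []
    # Monotonic counter breaks ties so equal priorities stay first-in, first-out.
    self._counter = itertools.count()

  def submit(self, command, priority=0):
    heapq.heappush(self._heap, (priority, next(self._counter), command))

  def pop(self):
    # Return the command with the smallest (priority, submission order) pair.
    _, _, command = heapq.heappop(self._heap)
    return command

  def __len__(self):
    return len(self._heap)


queue = ExperimentQueue()
queue.submit(['yann', 'train', '-m', 'resnet50', '-d', 'ImageNet'], priority=2)
queue.submit(['yann', 'train', '-m', 'resnet18', '-d', 'MNIST'], priority=1)
queue.submit(['yann', 'train', '-m', 'resnet34', '-d', 'MNIST'], priority=1)

# Drain the queue and record which model each popped command would train.
order = [queue.pop()[3] for _ in range(len(queue))]
```

A worker process would pop commands in this order and launch each one. A library such as Huey or Dask would add persistence and remote execution on top of the same idea.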
