
AMTNet: action-micro-tube-network

Action Micro Tube Network (AMTNet) - an implementation in PyTorch with linear heads

The training and evaluation code for AMTNet is written entirely in PyTorch. We build on the PyTorch implementation of our previous work, ROAD (released here).

The original SSD implementation was adapted from Max deGroot and Ellis Brown's implementation. We now use linear classification and regression heads instead of convolutional heads, because that change was needed for our other work, TraMNet. Linear heads are just as efficient, but GPU memory consumption increases slightly.
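As a rough sketch of the difference (layer names and sizes here are illustrative, not the repo's actual code), a convolutional head and a linear head can produce identically shaped per-anchor predictions from an SSD feature map:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; not the actual AMTNet configuration.
num_anchors, num_classes, channels = 6, 25, 512
feat = torch.randn(1, channels, 10, 10)  # one SSD feature map

# Convolutional head: a 3x3 conv predicts per-anchor class scores at each cell.
conv_head = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=3, padding=1)
conv_scores = conv_head(feat).permute(0, 2, 3, 1).reshape(1, -1, num_classes)

# Linear head: the same prediction as a fully connected layer applied to each
# cell's channel vector (equivalent to a 1x1 conv, hence similar efficiency).
linear_head = nn.Linear(channels, num_anchors * num_classes)
lin_scores = linear_head(feat.permute(0, 2, 3, 1)).reshape(1, -1, num_classes)

print(conv_scores.shape, lin_scores.shape)  # both: torch.Size([1, 600, 25])
```

Both heads emit 10 × 10 × 6 = 600 anchor predictions; the intermediate permuted tensors the linear head materializes are one plausible source of the extra memory.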

Table of Contents

Installation

  • Install PyTorch (version v1.0 as of March 2019) by selecting your environment on the website and running the appropriate command.
  • Please install cv2 and visdom from conda-forge.
  • We recommend using Anaconda with Python 3.7.
  • You will also need Matlab. If you have a Distributed Computing license, evaluation will be faster; otherwise it is still fine — just replace parfor with a simple for in the Matlab scripts. We would be happy to accept a PR for a Python version of this part.
  • Clone this repository.
    • Note: We currently only support Python 3.7 with PyTorch v1.0 on Linux systems.
  • We currently support UCF24 with the revised annotations released with our real-time online action detection paper. Unlike the ROAD implementation, we support JHMDB21 as well.
  • Similar to the ROAD setup, to simulate the same training and evaluation setup we provide extracted RGB images from the videos along with optical flow images (both Brox flow and real-time flow) computed for the UCF24 and JHMDB21 datasets. You can download them from my Google Drive link.
  • Install the opencv package for Anaconda using conda install opencv.
  • We also support Visdom for visualizing the loss and frame-meanAP on the validation subset during training.
    • To use Visdom in the browser:
    # First install Python server and client 
    conda install -c conda-forge visdom
    # Start the server (probably in a screen or tmux)
    python -m visdom.server --port=8097
    • Then (during training) navigate to http://localhost:8097/ (see the Training section below for more details).

Dataset

To make things easy, we provide extracted RGB images from the videos along with optical flow images (both Brox flow and real-time flow) computed for the UCF24 and JHMDB21 datasets; you can download them from my Google Drive link. Please download and extract them wherever you are going to store your experiments.

ActionDetection is a dataset loader class in data/dataset.py that inherits from torch.utils.data.Dataset, making it fully compatible with the torchvision.datasets API.
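A minimal sketch of that pattern (with synthetic tensors standing in for the real frame and tube loading that data/dataset.py performs — the class and field names below are illustrative, not the repo's):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyActionDetection(Dataset):
    """Toy stand-in for ActionDetection: seq_len frames plus a box+label target."""

    def __init__(self, num_samples=8, seq_len=2):
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # seq_len RGB frames of 300x300 (the real loader reads them from disk)
        frames = torch.randn(self.seq_len, 3, 300, 300)
        # one dummy target: [x1, y1, x2, y2, class]
        target = torch.tensor([[0.1, 0.1, 0.5, 0.5, 1.0]])
        return frames, target

loader = DataLoader(ToyActionDetection(), batch_size=4, num_workers=0)
frames, targets = next(iter(loader))
print(frames.shape)  # torch.Size([4, 2, 3, 300, 300])
```

Because it subclasses torch.utils.data.Dataset, the class plugs straight into DataLoader for batching and multi-worker loading.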

Training AMTNet

  • Similar to ROAD, AMTNet requires VGG-16 weights pretrained on UCF24 using the ROAD implementation.
  • Weights of the pretrained SSD used in ROAD can be downloaded from HERE. These weights are exactly the same as those produced by the SSD used in ROAD; using them reduces training time. Similar results can be achieved with ImageNet-pretrained models as well, but with different hyperparameters, which we have not kept track of.
  • If you want, you can train these weights yourself using ROAD.
  • Training a single-stream AMTNet can be done on a single 1080Ti GPU; it takes around 8GB of memory, given pretrained weight initialisation.
  • By default, we assume that you have downloaded the datasets and weights.
  • To train AMTNet using the training script, simply specify the parameters listed in train.py as flags or change them manually in the script.
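For illustration, flags like those in the commands below could be wired up with argparse. The flag names are taken from the commands in this README; the defaults here are guesses, so check train.py for the authoritative list:

```python
import argparse

# Hypothetical sketch; see train.py in the repo for the real flags and defaults.
parser = argparse.ArgumentParser(description='AMTNet training (illustrative)')
parser.add_argument('--seq_len', type=int, default=2)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--ngpu', type=int, default=1)
parser.add_argument('--fusion_type', default='NONE', choices=['NONE', 'SUM', 'CAT'])
parser.add_argument('--input_type_base', default='rgb')
parser.add_argument('--input_frames_base', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.0005)
parser.add_argument('--max_iter', type=int, default=70000)
parser.add_argument('--stepvalues', type=int, default=50000)
parser.add_argument('--val_step', type=int, default=10000)

# Parse a sample command line instead of sys.argv, just to demonstrate.
args = parser.parse_args(['--fusion_type=SUM', '--batch_size=6'])
print(args.fusion_type, args.batch_size)  # SUM 6
```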

Let's assume that you extracted the dataset in the /home/user/data/ucf24/ directory and the weights in /home/user/data/weights/. Now, your training command from the root directory of this repo is going to be:

RGB frames as input

python train.py --seq_len=2 --num_workers=4 --batch_size=8 --ngpu=2 --fusion_type=NONE --input_type_base=rgb --input_frames_base=1 --lr=0.0005 --max_iter=70000 --stepvalues=50000 --val_step=10000

Brox-flow as input

python train.py --seq_len=2 --num_workers=4 --batch_size=8 --ngpu=2 --fusion_type=NONE --input_type_base=brox --input_frames_base=5 --lr=0.0005 --max_iter=70000 --stepvalues=50000 --val_step=10000

Fusion

  • Copy the best models trained with the above commands to /home/user/data/weights/,
  • OR download the above pretrained models to that directory.
python train.py --seq_len=2 --num_workers=4 --batch_size=8 --ngpu=2 --fusion_type=SUM --input_type_base=rgb --input_type_extra=brox --input_frames_base=1 --input_frames_extra=5 --lr=0.0005 --max_iter=70000 --stepvalues=50000 --val_step=10000

Fusion notes

  • Here, you need 2 GPUs or 16GB of VRAM; alternatively, reduce the batch size to 6 or 4 and the learning rate to 0.0001. This is not guaranteed to reproduce the same results, but they will be close.
  • You can use --fusion_type=CAT for concatenation fusion. Sum fusion requires slightly less GPU memory.
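The memory difference is easy to see from the fused feature shapes (a toy sketch with made-up tensor sizes, not the repo's code):

```python
import torch

# Toy feature maps standing in for the RGB- and flow-stream activations.
rgb_feat = torch.randn(1, 512, 38, 38)
flow_feat = torch.randn(1, 512, 38, 38)

sum_fused = rgb_feat + flow_feat                      # SUM: channels stay at 512
cat_fused = torch.cat([rgb_feat, flow_feat], dim=1)   # CAT: channels double to 1024

print(sum_fused.shape, cat_fused.shape)
```

Concatenation doubles the channel dimension that subsequent layers must process, which is why CAT fusion needs somewhat more GPU memory than SUM.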

Different parameters in train.py will result in different performance.

  • Other notes:
    • A single-stream network occupies almost 8GB of VRAM on a GPU. We used a 1080Ti for training, and normal training takes about 20 hours; you can use a 1080 as well.
    • For instructions on Visdom usage/installation, see the Installation section. By default, it is off.
    • If you don't want to use Visdom, you can always keep track of training via the log file, which is saved under the save_root directory.
    • During training, a checkpoint is saved every 10K iterations, and its frame-level mean-AP on a subset of 15K test images is logged.
    • We recommend training for 60K iterations for all input types.

More instructions to Follow
