scottblack1998 / energy-based-scene-graph

This project forked from mods333/energy-based-scene-graph

Code release for Energy-Based Learning for Scene Graph Generation

License: Other

C++ 0.14% Python 7.26% C 0.07% Cuda 0.88% Jupyter Notebook 91.62% Dockerfile 0.03%

Energy-Based Learning for Scene Graph Generation

This repository contains the code for our paper Energy-Based Learning for Scene Graph Generation accepted at CVPR 2021.

Environment Setup

To set up the environment with all the required dependencies, follow the steps in Install.md.
Note: By default, the cudatoolkit version is set to 10.0. When creating an environment on your machine, check your CUDA compiler version by running nvcc --version and adjust the cudatoolkit version accordingly. Version mismatches can cause the build to fail or lead to segmentation faults when running the code.
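The version check above can be scripted. The sketch below is a hypothetical helper, not part of this repository: it parses the "release X.Y" line printed by nvcc --version and compares its major.minor version with the cudatoolkit pin. The function names are illustrative assumptions.

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract 'major.minor' from nvcc output such as
    'Cuda compilation tools, release 10.0, V10.0.130'."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def nvcc_release() -> Optional[str]:
    """Run nvcc --version if nvcc is on PATH; return None otherwise."""
    if shutil.which("nvcc") is None:
        return None
    out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
    return parse_nvcc_release(out)


def versions_match(release: Optional[str], cudatoolkit: str) -> bool:
    """True when the compiler and the cudatoolkit pin agree on major.minor."""
    if release is None:
        return False
    return release.split(".")[:2] == cudatoolkit.split(".")[:2]
```

For example, if versions_match(nvcc_release(), "10.0") returns False, the cudatoolkit pin used when following Install.md should be changed to the release nvcc reports before building.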

DATASET

Check Dataset.md for details on downloading the datasets.

Pre-Trained Models

We released the weights for the pretrained VCTree model on the Visual Genome dataset, trained using both cross-entropy (CE) based and energy-based (EBM) training.

| EBM | CE |
|---|---|
| VCTree-PredCLS | VCTree-PredCLS |
| VCTree-SGCLS | VCTree-SGCLS |
| VCTree-SGDET | VCTree-SGDET |

To train your own models, you can obtain the weights for the pretrained detector from this repository.

Training for Energy-Based Scene Graph Generation

python -m torch.distributed.launch --master_port 10001 --nproc_per_node=4 \
    tools/energy_joint_train_cd.py --config-file configs/e2e_relation_X_101_32_8_FPN_1x.yaml \
    MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL True \
    MODEL.ROI_RELATION_HEAD.PREDICTOR VCTreePredictor \
    SOLVER.IMS_PER_BATCH 16  TEST.IMS_PER_BATCH 4 \
    DTYPE float16 SOLVER.MAX_ITER 50000 SOLVER.VAL_PERIOD 2000 \
    SOLVER.CHECKPOINT_PERIOD 2000 \
    GLOVE_DIR $GLOVE_DIR \
    MODEL.PRETRAINED_DETECTOR_CKPT $PRETRAINED_DETECTOR_PATH \
    OUTPUT_DIR $OUTPUT_DIR \
    SOLVER.BASE_LR 0.001 SAMPLER.LR 1.0 SAMPLER.ITERS 20 SAMPLER.VAR 0.001 SAMPLER.GRAD_CLIP 0.01 MODEL.DEV_RUN False
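The trailing KEY VALUE pairs in the command override fields of the YAML config. Assuming this repository keeps the conventions of the maskrcnn_benchmark code it builds on (overrides merged pairwise via yacs merge_from_list, and SOLVER.IMS_PER_BATCH treated as a global batch that must divide evenly across GPUs), the sketch below shows how those tokens pair up and how 16 training images become 4 per GPU on 4 GPUs. The helpers pair_overrides and images_per_gpu are hypothetical.

```python
from typing import Dict, List


def pair_overrides(opts: List[str]) -> Dict[str, str]:
    """Group trailing 'KEY VALUE' tokens into a dict, two tokens at a time."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    return dict(zip(opts[0::2], opts[1::2]))


def images_per_gpu(total_images: int, num_gpus: int) -> int:
    """Split a global batch size evenly across GPUs (assumed convention)."""
    if total_images % num_gpus != 0:
        raise ValueError(f"{total_images} images cannot be split evenly over {num_gpus} GPUs")
    return total_images // num_gpus


# A subset of the tokens from the training command above.
opts = [
    "SOLVER.IMS_PER_BATCH", "16", "TEST.IMS_PER_BATCH", "4",
    "SOLVER.MAX_ITER", "50000", "SAMPLER.ITERS", "20",
]
cfg = pair_overrides(opts)
per_gpu = images_per_gpu(int(cfg["SOLVER.IMS_PER_BATCH"]), num_gpus=4)
```

Under the same assumption, if you change --nproc_per_node, choose SOLVER.IMS_PER_BATCH and TEST.IMS_PER_BATCH values that are divisible by the new GPU count.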

The above script trains a model using 4 GPUs. Here is how to change the training behavior for various requirements:

  1. Scene Graph Generation Tasks
    1. For PredCLS set
      MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL True
    2. For SGCLS set
      MODEL.ROI_RELATION_HEAD.USE_GT_BOX True MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False
    3. For SGDet set
      MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False
  2. Changing scene graph prediction model
    Change MODEL.ROI_RELATION_HEAD.PREDICTOR to one of the available models.
  3. Modifying Sampler
    The current implementation has only a single sampler (SGLD). You can implement samplers of your choice in maskrcnn_benchmark/modeling/energy_head/sampler.py. To change the parameters of the sampler, use the fields under SAMPLER in the config.
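The task switches in item 1 amount to a two-flag lookup table. The minimal sketch below uses a hypothetical task_overrides helper, not part of this repository, that returns the tokens to append to the training command for a given task.

```python
from typing import Dict, List, Tuple

# (USE_GT_BOX, USE_GT_OBJECT_LABEL) for each task, as listed in item 1.
TASK_FLAGS: Dict[str, Tuple[bool, bool]] = {
    "predcls": (True, True),
    "sgcls": (True, False),
    "sgdet": (False, False),
}


def task_overrides(task: str) -> List[str]:
    """Return the KEY VALUE tokens that select a scene graph generation task."""
    use_gt_box, use_gt_label = TASK_FLAGS[task.lower()]
    return [
        "MODEL.ROI_RELATION_HEAD.USE_GT_BOX", str(use_gt_box),
        "MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL", str(use_gt_label),
    ]
```

For instance, " ".join(task_overrides("sgdet")) yields the two pairs that would replace the USE_GT_BOX and USE_GT_OBJECT_LABEL settings in the training command to switch the run to SGDet.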
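As a reference point when writing a new sampler, here is a generic stochastic gradient Langevin dynamics (SGLD) loop in plain Python. It is not the code in sampler.py. Its parameters mirror the SAMPLER fields used in the training command (LR, ITERS, VAR, GRAD_CLIP), but two details are assumptions: the gradient is clipped element-wise to [-GRAD_CLIP, GRAD_CLIP], and VAR is treated as the variance of the Gaussian noise added at each step. In the real model the state would be the scene graph representation and the gradient would come from autograd on the energy network; a toy quadratic energy stands in here.

```python
import math
import random
from typing import Callable, List, Optional

Vector = List[float]


def sgld_sample(
    x0: Vector,
    grad_energy: Callable[[Vector], Vector],
    lr: float = 1.0,          # mirrors SAMPLER.LR
    iters: int = 20,          # mirrors SAMPLER.ITERS
    var: float = 0.001,       # mirrors SAMPLER.VAR (assumed: noise variance)
    grad_clip: float = 0.01,  # mirrors SAMPLER.GRAD_CLIP (assumed: element-wise clip)
    rng: Optional[random.Random] = None,
) -> Vector:
    """Langevin steps: x <- x - lr * clip(dE/dx) + Gaussian noise with variance var."""
    rng = rng or random.Random()
    std = math.sqrt(var)
    x = list(x0)
    for _ in range(iters):
        grad = grad_energy(x)
        x = [
            xi - lr * max(-grad_clip, min(grad_clip, gi)) + rng.gauss(0.0, std)
            for xi, gi in zip(x, grad)
        ]
    return x


# Toy energy E(x) = 0.5 * ||x - target||^2, whose gradient is x - target.
target = [5.0, -5.0]


def quadratic_grad(x: Vector) -> Vector:
    return [xi - ti for xi, ti in zip(x, target)]
```

With GRAD_CLIP at 0.01 and LR at 1.0, the gradient term moves each coordinate by at most 0.01 per step (0.2 over 20 iterations), so under these assumptions the sampler performs a small, noisy local refinement rather than a full minimization of the energy.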

Acknowledgment

This repository is developed on top of the scene graph benchmarking framework developed by KaihuaTang.

Contributors

mods333, kaihuatang, skyil7, karim-53, navidre