
rintarooo / tsp_drl_ptrnet


"Neural Combinatorial Optimization with Reinforcement Learning"[Bello+, 2016], Traveling Salesman Problem solver

License: MIT License

Python 93.93% Dockerfile 4.53% Shell 1.53%
tsp deep-reinforcement-learning actor-critic pointer-networks active-search pytorch


TSP Solver with Deep RL

This is a PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning, Bello et al. 2016 [https://arxiv.org/abs/1611.09940]

Pointer Network is the model architecture proposed by Vinyals et al. 2015 [https://arxiv.org/abs/1506.03134].

This model uses an attention mechanism to output a permutation of the input indices.
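To illustrate how attention can "point" at input indices, here is a minimal pure-Python sketch of greedy pointer decoding, assuming the additive attention score u_j = v · tanh(W_ref e_j + W_q q) from Vinyals et al. The weights are random placeholders, the function names are hypothetical, and the LSTM decoder is simplified away (the query is just the last chosen city's embedding), so this is not the repo's actor.py. Already-visited cities are masked out, which is what guarantees the output is a permutation.

```python
import math
import random

def matvec(m, x):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def softmax(scores):
    # Stable softmax; masked entries (-inf) get probability 0.
    top = max(s for s in scores if s != -math.inf)
    exps = [0.0 if s == -math.inf else math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def pointer_decode(embeddings, w_ref, w_q, v, first_query):
    """Greedy pointer decoding; returns a permutation of city indices."""
    n = len(embeddings)
    refs = [matvec(w_ref, e) for e in embeddings]  # W_ref e_j, computed once
    visited = [False] * n
    query = first_query
    tour = []
    for _ in range(n):
        q = matvec(w_q, query)
        scores = []
        for j in range(n):
            if visited[j]:
                scores.append(-math.inf)  # mask cities already in the tour
            else:
                hidden = [math.tanh(r + c) for r, c in zip(refs[j], q)]
                scores.append(sum(a * h for a, h in zip(v, hidden)))  # u_j
        probs = softmax(scores)
        nxt = max(range(n), key=lambda j: probs[j])  # greedy: point at argmax
        tour.append(nxt)
        visited[nxt] = True
        # Simplification: the next query is the chosen city's embedding
        # (a real pointer network passes it through an LSTM decoder).
        query = embeddings[nxt]
    return tour

# Usage with random placeholder weights (embedding size d = 4, 5 cities).
rng = random.Random(0)
d = 4

def rand_matrix(rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

cities = [[rng.random(), rng.random()] for _ in range(5)]
w_emb = rand_matrix(d, 2)  # hypothetical linear embedding of (x, y)
embeddings = [matvec(w_emb, c) for c in cities]
v = [rng.gauss(0, 1) for _ in range(d)]
tour = pointer_decode(embeddings, rand_matrix(d, d), rand_matrix(d, d), v, [0.0] * d)
print(tour)
```

During training the actor samples from probs instead of taking the argmax; greedy decoding is shown here only because it is deterministic.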




In this work, we tackle the Traveling Salesman Problem (TSP), an NP-hard combinatorial optimization problem. TSP seeks the shortest tour in which a salesman visits each city exactly once and returns to the starting city.
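The quantity being minimized is the length of a closed tour. A minimal sketch, assuming 2D Euclidean city coordinates and a tour given as a list of city indices (the function name tour_length is illustrative, not from the repo); it uses math.hypot so it also runs on the Python 3.6 listed under Dependencies:

```python
import math

def tour_length(cities, tour):
    """Length of the closed tour visiting cities in `tour` order and returning to the start."""
    total = 0.0
    for i in range(len(tour)):
        ax, ay = cities[tour[i]]
        bx, by = cities[tour[(i + 1) % len(tour)]]  # wrap around to close the tour
        total += math.hypot(ax - bx, ay - by)
    return total

square = [(0, 0), (1, 0), (1, 1), (0, 1)]
print(tour_length(square, [0, 1, 2, 3]))  # 4.0
print(tour_length(square, [0, 2, 1, 3]))  # crossing tour, 2 + 2*sqrt(2)
```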

Training without supervised solution

In the training phase, this TSP solver optimizes two different types of Pointer Networks: the Actor model and the Critic model.

Given a graph whose nodes are cities, the critic model predicts the expected tour length, which is generally called the state value. The critic's parameters are optimized so that its estimated tour length moves toward the actual length of the tour (city permutation) predicted by the actor model. The actor model updates its policy parameters using the advantage, which is the actual tour length minus the state value.
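The two objectives described above can be sketched as plain arithmetic over a batch. This is a minimal illustration, assuming REINFORCE with the critic's prediction as the baseline (as in Bello et al.) and a mean-squared-error critic loss; the function name is hypothetical, and in PyTorch the advantage would be detached so the actor loss does not backpropagate into the critic.

```python
def actor_critic_losses(tour_lengths, log_probs, baselines):
    """Return (actor_loss, critic_loss) averaged over a batch.

    tour_lengths: actual length L of each tour sampled by the actor
    log_probs:    log p(tour | graph) of each sampled tour under the actor
    baselines:    critic's predicted length b (the state value) for each graph
    """
    n = len(tour_lengths)
    # Advantage: actual tour length minus the critic's estimate.
    advantages = [L - b for L, b in zip(tour_lengths, baselines)]
    # REINFORCE: minimizing mean(advantage * log_prob) lowers the probability
    # of tours that turned out longer than the critic predicted.
    actor_loss = sum(a * lp for a, lp in zip(advantages, log_probs)) / n
    # Critic regression: pull the predicted length toward the actual length.
    critic_loss = sum((b - L) ** 2 for L, b in zip(tour_lengths, baselines)) / n
    return actor_loss, critic_loss

print(actor_critic_losses([5.0, 3.0], [-1.0, -2.0], [4.0, 4.0]))  # (0.5, 1.0)
```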

Actor-Critic

Actor: defines the agent's behavior (its policy)
Critic: estimates the state value



Inference

Active Search and Sampling

The paper proposes two approaches for finding the best tour at inference time, which we refer to as Sampling and Active Search.

Active Search takes the actor model and keeps updating its parameters with policy gradients to find the shortest tour. Sampling simply draws one batch of tours from the actor and selects the shortest one.
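Sampling can be sketched independently of the network. In this minimal example, a hypothetical stand-in policy replaces the trained actor: from the current city it picks the next unvisited city with probability proportional to exp(-distance / temperature). The function names, temperature, and batch size are assumptions; only the "sample a batch, keep the shortest" logic mirrors the description above. Active Search is not shown, since it additionally needs gradient updates to the actor.

```python
import math
import random

def tour_length(cities, tour):
    # Closed-tour Euclidean length (returns to the start city).
    total = 0.0
    for i in range(len(tour)):
        ax, ay = cities[tour[i]]
        bx, by = cities[tour[(i + 1) % len(tour)]]
        total += math.hypot(ax - bx, ay - by)
    return total

def sample_tour(cities, temperature, rng):
    """Hypothetical stand-in for the actor: favors nearby unvisited cities."""
    tour = [0]
    remaining = set(range(1, len(cities)))
    while remaining:
        lx, ly = cities[tour[-1]]
        cand = sorted(remaining)
        weights = [math.exp(-math.hypot(lx - cities[j][0], ly - cities[j][1]) / temperature)
                   for j in cand]
        nxt = rng.choices(cand, weights=weights, k=1)[0]
        tour.append(nxt)
        remaining.remove(nxt)
    return tour

def sampling_search(cities, batch_size=128, temperature=0.5, seed=0):
    """Sample batch_size tours and keep the shortest one."""
    rng = random.Random(seed)
    best_tour, best_len = None, math.inf
    for _ in range(batch_size):
        tour = sample_tour(cities, temperature, rng)
        length = tour_length(cities, tour)
        if length < best_len:
            best_tour, best_len = tour, length
    return best_tour, best_len

city_rng = random.Random(42)
cities = [(city_rng.random(), city_rng.random()) for _ in range(10)]
best_tour, best_len = sampling_search(cities, batch_size=64)
print(best_tour, round(best_len, 3))
```

A larger batch can only keep or improve the best length found, at the cost of more forward passes.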


Usage

Training

First, generate the pickle file containing the hyperparameter values by running the following command (in this example: train mode, batch size 512, 20 city nodes, 13000 steps):

python config.py -m train -b 512 -t 20 -s 13000

-m train can be replaced with -m train_emv, where emv stands for 'Exponential Moving Average'. This mode does not need the critic model. Then, start training:

python train.py -p Pkl/train20.pkl
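For the -m train_emv option, a moving average of past tour lengths plays the baseline role that the critic plays otherwise. A minimal sketch, assuming the baseline tracks each batch's mean tour length and is updated after the advantages are computed (the update rule b ← beta·b + (1 − beta)·mean(L) used in Bello et al.'s Active Search). The class name EMABaseline and the decay beta are hypothetical; the repo's train_emv mode may differ in initialization and decay.

```python
class EMABaseline:
    """Exponential moving average of tour lengths used as a REINFORCE baseline."""

    def __init__(self, beta=0.9):
        self.beta = beta    # decay rate (hypothetical default)
        self.value = None   # no estimate until the first batch arrives

    def advantages(self, tour_lengths):
        batch_mean = sum(tour_lengths) / len(tour_lengths)
        if self.value is None:
            self.value = batch_mean  # initialize with the first batch mean
        # Advantages use the baseline from *before* this batch's update.
        adv = [L - self.value for L in tour_lengths]
        self.value = self.beta * self.value + (1 - self.beta) * batch_mean
        return adv

baseline = EMABaseline(beta=0.5)
print(baseline.advantages([4.0, 6.0]))   # [-1.0, 1.0]
print(baseline.advantages([8.0, 10.0]))  # [3.0, 5.0]
print(baseline.value)                    # 7.0
```

The advantages then feed the same actor loss as in the actor-critic setup, without training a second network.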



Inference

Once training is done, set the configuration for inference.
You can check how the training process went in the CSV files in the Csv directory.
You may use my pre-trained weights, Pt/train20_1113_12_12_step14999_act.pt, which I trained on 20-node instances.

python config.py -m test -t 20 -s 10 -ap Pt/train20_1113_12_12_step14999_act.pt --islogger --seed 123
python test.py -p Pkl/test20.pkl
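The two-step workflow (config.py writes a pickle, then train.py or test.py reads it via -p) can be sketched as follows. This is a hypothetical reconstruction, not the repo's config.py: only the short flags come from the commands above, while the dest names, defaults, and file-name pattern (inferred from Pkl/train20.pkl and Pkl/test20.pkl) are assumptions.

```python
import argparse
import os
import pickle
import tempfile

def build_parser():
    # Hypothetical parser covering only the flags shown in this README.
    p = argparse.ArgumentParser()
    p.add_argument("-m", dest="mode", choices=["train", "train_emv", "test"], default="train")
    p.add_argument("-b", dest="batch", type=int, default=512)   # placeholder default
    p.add_argument("-t", dest="city_t", type=int, default=20)   # number of cities
    p.add_argument("-s", dest="steps", type=int, default=13000)
    p.add_argument("-ap", dest="act_model_path", default=None)  # actor weights (.pt)
    p.add_argument("--islogger", action="store_true")
    p.add_argument("--seed", type=int, default=1)
    return p

def save_config(args, out_dir):
    # Assumed naming pattern: <mode><city_t>.pkl, e.g. test20.pkl.
    path = os.path.join(out_dir, f"{args.mode}{args.city_t}.pkl")
    with open(path, "wb") as f:
        pickle.dump(vars(args), f)
    return path

def load_config(path):
    with open(path, "rb") as f:
        return pickle.load(f)

args = build_parser().parse_args(
    ["-m", "test", "-t", "20", "-s", "10", "--islogger", "--seed", "123"])
with tempfile.TemporaryDirectory() as out_dir:
    path = save_config(args, out_dir)
    cfg = load_config(path)
print(os.path.basename(path), cfg["steps"], cfg["seed"])  # test20.pkl 10 123
```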



Environment

My own environment is listed below. I tested the code on a single GPU.

  • OS:
    • Linux(Ubuntu 18.04.5 LTS)
  • GPU:
    • NVIDIA® GeForce® RTX 2080 Ti VENTUS 11GB OC
  • CPU:
    • Intel® Xeon® CPU E5640 @ 2.67GHz
  • NVIDIA® Driver = 455.45.01
  • Docker = 20.10.3
  • nvidia-docker2(for GPU)

Dependencies

  • Python = 3.6.10
  • PyTorch = 1.6.0
  • numpy
  • tqdm (optional)
  • matplotlib (only for plotting)

Docker (optional)

Make sure you have already installed Docker,

docker version

the latest NVIDIA® driver,

nvidia-smi

and nvidia-docker2 (for GPU).

Usage

  1. Build or pull the Docker image.

Build the image (this might take some time):

./docker.sh build

Or pull the image from Docker Hub:

docker pull docker4rintarooo/tspdrl:latest

  2. Run a container using the Docker image (the -v option mounts a directory):

./docker.sh run

If you don't have a GPU, run:

./docker.sh run_cpu



Reference


tsp_drl_ptrnet's Issues

possible error in `critic.py` `forward()`?

Very helpful repo!
One question: there might be an error in the forward function in critic.py.
In line 37, the Decoder always takes the same initial dec_input for each city in the sequence, whereas it should take the output from the previous city, as in actor.py, where dec_input is updated after processing each city.
Thanks in advance, and looking forward to your reply!


Update:
Actually, I think I got mixed up. My understanding now is that for the actor, dec_input should be the embedding of the action sampled from the probability output at the corresponding time step, instead of the updated weighted sum of ref as is currently done in actor.py. But then I'm very confused about how this should be done in critic.py: should it sample separately from the actor?
