"On Decoding", or, how to learn the model $ \log p(y | x ) = \sum_{j=0}^{m-1} \log p(y_{j+1} | y_{\le_{j}}, x)$ and estimate $y^{} = \arg\max_{y} p(y|x)$, using a transformer architecture. (Note that the sum is over terms of the form $\log p(y_{j+1} | y_{\le_{j}}, x)$ as this is the meaning of being autoregressive--to generative the locally next token, we condition on all previously generated tokens.) \ This estimation is exponential in the length of the decoded string $y$. (To be precise it is $y^{|V|}$ where $V$ is the vocabulary size).\ The following methods reduce the search space of large language model inference while approximating the optimal $y^{}$.\
ai_planning_searching's Issues
The select function should populate the mcts_tree. Modify test_select to check that child nodes are appended.
see title
Fix segmentation fault due to load_dataset, optionally bypass the huggingface dataloader
Read the APPS paper and import the dataset into codebase
- Read the APPS paper: https://arxiv.org/pdf/2105.09938
- Import the APPS dataset
[p1] Refactor code to allow for passing in the question as the prompt to MCTS (main_algorithm) and the accompanying unit tests as the reward function (evaluate_full_paths)
[p1] Build a dataloader that loads the APPS dataset question (prompt) with its associated unit tests, which allows shuffling, and multiple epoch loading, et cetera
fix test_expand
- Code is getting unreadable. Create a Beam_Item class to keep track of candidate beams, in particular with a field beam_tokens that keeps track of the beam's tokens as a single Tensor of shape (1, seq_len_so_far)
AND
- After each call to logits_to_token_strings, make sure beam_tokens is updated with the newly generated token
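A torch-free sketch of what Beam_Item could look like; a plain list stands in for the (1, seq_len_so_far) Tensor, and everything beyond the beam_tokens field name is an assumption:

```python
from dataclasses import dataclass, field

@dataclass
class Beam_Item:
    """One candidate beam. beam_tokens plays the role of the
    (1, seq_len_so_far) Tensor from the issue; a plain list here."""
    beam_tokens: list = field(default_factory=list)
    log_prob: float = 0.0  # cumulative log-probability of the beam

    def extend(self, token_id, token_log_prob):
        """Return a new Beam_Item with the newly generated token appended,
        as required after each logits_to_token_strings call."""
        return Beam_Item(self.beam_tokens + [token_id],
                         self.log_prob + token_log_prob)
```

Making extend return a fresh Beam_Item (rather than mutating in place) keeps the top-k candidate bookkeeping in beam search simple, since several extensions of the same beam can coexist.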
[p0] Set attention mask and pad token in beam generate
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.
Setting pad_token_id to eos_token_id: 50256 for open-end generation.
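The fix is to pass both values explicitly to generate. As a library-free illustration of what the two arguments mean, here is right-padding with a hand-built attention mask, reusing eos_token_id = 50256 as the pad token as in the log above:

```python
EOS_TOKEN_ID = 50256  # doubles as pad_token_id, as in the warning

def pad_batch(sequences, pad_id=EOS_TOKEN_ID):
    """Right-pad variable-length token-id lists to a rectangle and
    return the matching attention mask (1 = real token, 0 = padding)."""
    max_len = max(len(s) for s in sequences)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s))
                      for s in sequences]
    return input_ids, attention_mask
```

These two lists are what would then be handed to the beam-generate call as input_ids and attention_mask, along with pad_token_id, which silences the warning.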
[p1] The plots! The plots! Do a comparison with greedy decode and beam decode in a jupyter notebook
The goal is to show that MCTS on the apps test dataset (without finetuning on the apps training set) gets higher return than beam decode and greedy decode.
fix backpropagate_statistics
line 252 in backpropagate_statistics needs to update P_UCB_s_a
In particular, what is the representation of P_s_a we are storing and does it accord with the MCTS algorithm as intended?
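For comparison, one standard form of the score is the PUCT variant, $P\_UCB(s,a) = Q(s,a) + c \cdot P(s,a) \cdot \sqrt{N(s)} / (1 + N(s,a))$; the constant c and this exact formula are assumptions, not necessarily what line 252 intends:

```python
import math

def p_ucb(q_s_a, p_s_a, n_s, n_s_a, c=1.0):
    """PUCT-style score: exploitation term Q(s,a) plus an exploration
    bonus scaled by the prior P(s,a) and the parent visit count N(s),
    discounted by the action visit count N(s,a)."""
    return q_s_a + c * p_s_a * math.sqrt(n_s) / (1 + n_s_a)
```

Under this form, P_s_a is stored as a per-action prior probability (here, the model's token probability), fixed at expansion time, while Q_s_a and the visit counts are what backpropagate_statistics updates.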
Add option to use some other function f to score trajectories, as currently the MCTS code just implements beam search
code cleanup
decide on some standard for the format of executing python inputs
i.e. using a locals-style namespace:

program = """def exponentiate(x):
    return x**2"""
exec(program)  # note: exec returns None, so results must be read back out of the namespace

or an f-string like:

"""
x = {unit_test_input}
def exponentiate(x):
    return x**2
"""
[p0] apps_dataset playground
- Make a jupyter notebook for playing with the APPS dataset
- Write a function (call it verify_exec_smoothness) that samples questions from the dataset (q_statement), samples a ground-truth solution (sol_gt), samples from the input space of each question (q_input), and feeds the sampled input (q_input) to the sampled ground-truth solution (sol_gt) using the exec function. (Currently the notebook requires keyboard input; if this is too difficult, then try running it as a python script.) This function should verify that the test dataset is executable using exec() [or some modification of it] almost everywhere in the dataset
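The verification loop could be sketched as follows. The record fields ("solutions", "inputs") and the unit_test_input variable name are hypothetical placeholders, since the actual APPS schema is not pinned down here:

```python
import random

def verify_exec_smoothness(dataset, n_samples=100, seed=0):
    """Sample (question, ground-truth solution, input) triples and check
    that exec() can run the solution on the input. Returns the fraction
    of samples that executed without raising."""
    rng = random.Random(seed)
    ok = 0
    for _ in range(n_samples):
        record = rng.choice(dataset)
        sol_gt = rng.choice(record["solutions"])
        q_input = rng.choice(record["inputs"])
        try:
            namespace = {"unit_test_input": q_input}
            exec(sol_gt, namespace)  # solution reads unit_test_input
            ok += 1
        except Exception:
            pass
    return ok / n_samples
```

A returned fraction well below 1.0 would mean the test dataset is not executable via exec() "almost everywhere" and the reward function needs a different harness.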
Add a test to ensure that Node.current_token indeed only represents the current token in the sequence and not the entire token sequence
refactor select function
Currently the select function does not do what its docstring says, i.e. "traverse the tree to find a node that has not been previously expanded (a node without children nodes)."
Instead it has too many other responsibilities folded in, such as calling expand.
Then, after this refactor, rewrite test_select.
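After the refactor, a docstring-faithful select might look like the following sketch; the Node fields here are assumptions, and scoring and expansion stay outside the function:

```python
class Node:
    """Minimal stand-in for the repo's Node: a value plus children."""
    def __init__(self, value=0.0, children=None):
        self.value = value
        self.children = children or []

def select(root, score):
    """Traverse the tree to find a node that has not been previously
    expanded (a node without children). At each step descend to the
    child maximizing `score`; expand is never called here -- that is
    the caller's job after select returns."""
    node = root
    while node.children:
        node = max(node.children, key=score)
    return node
```

With this shape, test_select only needs to build a small tree and assert that the returned node is childless and lies on the max-score path.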
[p1] rewrite reward_function_utils.py
Todo after #16
node 0 does not contain the correct max_rollout_reward as the Q_s_a value for child action '3' in test_backpropagate_statistics
see title
Make the code compatible with using a multi-token prompt and doing MCTS one token at a time, i.e. a minor modification to the test in test_main_algorithm
[p2] try some form of parallelization
Monte Carlo tree search can be concurrently executed by many threads or processes. There are several fundamentally different methods of its parallel execution:[52]
Leaf parallelization, i.e. parallel execution of many playouts from one leaf of the game tree.
Root parallelization, i.e. building independent game trees in parallel and making the move based on the root-level branches of all these trees.
Tree parallelization, i.e. parallel building of the same game tree, protecting data from simultaneous writes either with one, global mutex, with more mutexes, or with non-blocking synchronization.[53]
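Of the three, leaf parallelization is the least invasive fit for the current code: one select/expand step, then many playouts from the chosen leaf run concurrently. A sketch with the stdlib thread pool, where rollout is a hypothetical stand-in for the repo's playout function:

```python
import random
from concurrent.futures import ThreadPoolExecutor

def rollout(leaf_state, seed):
    """Hypothetical playout: returns a reward for one simulation.
    The real version would sample a completion with the model and
    score it against the unit tests."""
    return random.Random(seed).random() * leaf_state

def leaf_parallel_playouts(leaf_state, n_playouts=8, n_workers=4):
    """Leaf parallelization: many independent playouts from the same
    leaf run concurrently; the mean reward is backpropagated once."""
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        rewards = list(pool.map(lambda s: rollout(leaf_state, s),
                                range(n_playouts)))
    return sum(rewards) / len(rewards)
```

Since the tree itself is only touched once per batch of playouts, this variant needs no mutexes, unlike tree parallelization.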