rmanluo / reasoning-on-graphs

Official Implementation of ICLR 2024 paper: "Reasoning on Graphs: Faithful and Interpretable Large Language Model Reasoning"

Home Page: https://arxiv.org/abs/2310.01061

License: MIT License

Shell 3.76% Python 96.24%
kg knowledge large-language-models llm reasoning reasoning-on-graph

reasoning-on-graphs's Introduction

Reasoning on Graphs (RoG)

Official Implementation of "Reasoning on Graphs: Faithful and Interpretable Large Language Model Reasoning".

Reasoning on graphs (RoG) synergizes LLMs with KGs to enable faithful and interpretable reasoning. We present a planning-retrieval-reasoning framework, where RoG first generates relation paths grounded by KGs as faithful plans. These plans are then used to retrieve valid reasoning paths from the KGs for LLMs to conduct faithful reasoning and generate interpretable results.
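The retrieval step can be pictured as walking the KG along a generated relation path (the plan) starting from a question entity. The following is a minimal sketch, not the repository's implementation: it assumes the KG is an in-memory list of (head, relation, tail) triples, and the helper names `build_index`, `retrieve_reasoning_paths`, and `format_path` are hypothetical.

```python
from collections import defaultdict


def build_index(triples):
    """Map (head, relation) -> list of tails so each hop is a dict lookup."""
    index = defaultdict(list)
    for head, rel, tail in triples:
        index[(head, rel)].append(tail)
    return index


def retrieve_reasoning_paths(index, start, relation_path):
    """Follow a relation path (the plan) from a question entity.

    Returns every entity-level path [start, r1, e1, r2, e2, ...] that is
    grounded in the KG. An empty list means the plan has no instance here.
    """
    paths = [[start]]
    for rel in relation_path:
        next_paths = []
        for path in paths:
            # .get() avoids inserting empty keys into the defaultdict.
            for tail in index.get((path[-1], rel), []):
                next_paths.append(path + [rel, tail])
        paths = next_paths
    return paths


def format_path(path):
    """Render a path in the 'e -> r -> e' style used in RoG prompts."""
    return " -> ".join(path)
```

For example, with the two triples behind the interpretable-reasoning example (Northern District, location.administrative_division.first_level_division_of, Israel) and (Israel, government.form_of_government.countries, Parliamentary system), following that two-relation plan from "Northern District" yields exactly one path, which `format_path` renders as the "Reasoning Paths" line shown in that example. Indexing by (head, relation) keeps each hop cheap when many plans are expanded.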

Requirements

pip install -r requirements.txt

Pre-trained weights

Our code automatically downloads the model weights from Hugging Face.

You can find the pre-trained weights here.

Datasets

Our code automatically downloads the data from Hugging Face.

RoG-WebQSP
RoG-CWQ

Subgraph Extraction

We extract the subgraphs from Freebase following previous studies. The code can be found here.

Inference

Requirements: Any GPU with at least 12GB memory.

Step 1: Planning (Generate relation paths)

Run: ./scripts/planning.sh

python src/qa_prediction/gen_rule_path.py \
        --model_name RoG \
        --model_path rmanluo/RoG \
        -d {RoG-webqsp,RoG-cwq} \
        --split test \
        --n_beam 3

Generated rules will be saved at: results/gen_rule_path/{dataset}/{model_name}/{split}

Step 2: Reasoning (Generate answers with RoG)

Run: ./scripts/rog-reasoning.sh

python src/qa_prediction/predict_answer.py \
        --model_name RoG \
        --model_path rmanluo/RoG \
        -d {RoG-webqsp,RoG-cwq} \
        --prompt_path prompts/llama2_predict.txt \
        --add_rule \
        --rule_path {rule_path}

Answers will be saved at: results/KGQA/{dataset}/{model_name}/{split}

Plug-and-play Reasoning (Generate answers with different LLMs)

Note: to use ChatGPT, set your OpenAI API key in the .env file.

Run: ./scripts/plug-and-play.sh

python src/qa_prediction/predict_answer.py \
        --model_name {gpt-3.5-turbo,alpaca,llama2-chat-hf,flan-t5} \
        -d {RoG-webqsp,RoG-cwq} \
        --prompt_path {prompt_path} \
        --add_rule \
        --rule_path {rule_path}

Interpretable Reasoning

Run: python scripts/interpretable_example.py

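from transformers import pipeline, AutoTokenizer
import torch

# RoG weights are hosted on the Hugging Face Hub and downloaded on first use.
MODEL_PATH_OR_NAME = "rmanluo/RoG"

# Load the slow (non-Rust) tokenizer and an fp16 text-generation pipeline;
# device_map="auto" lets Accelerate place the model on the available device(s).
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH_OR_NAME, use_fast=False)
model = pipeline("text-generation", model=MODEL_PATH_OR_NAME, tokenizer=tokenizer, device_map="auto", torch_dtype=torch.float16)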

print("====EXAMPLE 1: ====")

INPUT_TEXT_1 = """Based on the reasoning paths, please answer the given question and explain why 

Reasoning Paths: 
Northern District -> location.administrative_division.first_level_division_of -> Israel -> government.form_of_government.countries -> Parliamentary system

Question: 
What type of government is used in the country with Northern District?"""

outputs = model(INPUT_TEXT_1, return_full_text=False)
print(outputs[0]['generated_text'])
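When you have more than one retrieved path, it can help to assemble the prompt programmatically instead of hard-coding it. The helper below is a hypothetical sketch that follows the same layout as INPUT_TEXT_1 (instruction, "Reasoning Paths:", one path per line, "Question:"); it drops the trailing spaces present in the hand-written string, and it is not the prompt template the repository ships in prompts/.

```python
def build_prompt(reasoning_paths, question):
    """Hypothetical helper: lay out paths and question like INPUT_TEXT_1."""
    header = (
        "Based on the reasoning paths, please answer the given question "
        "and explain why"
    )
    # One reasoning path per line, e.g. "e1 -> r1 -> e2 -> r2 -> e3".
    paths_block = "\n".join(reasoning_paths)
    return f"{header}\n\nReasoning Paths:\n{paths_block}\n\nQuestion:\n{question}"
```

Usage with the pipeline above: `outputs = model(build_prompt(paths, question), return_full_text=False)`, where `paths` is a list of path strings and `question` is the natural-language question.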

Training

Training Datasets

You can download the processed datasets from RoG_train_data.tar.tz. Extract the files and put them under the datasets/ folder.

Process datasets
  1. Build question to relation path pairs.
python src/align_kg/build_align_qa_dataset.py -d {RoG-webqsp,RoG-cwq} --split {train,validation,test}
  2. Build joint-training datasets.
python src/joint_training/preprocess_align.py
python src/joint_training/preprocess_qa.py
  3. Build interpretable examples.
python src/joint_training/generate_explanation_results.py
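A common way to build question-to-relation-path pairs is to take the shortest KG paths between a question entity and an answer entity and keep only their relation sequences. The sketch below assumes that approach and an in-memory list of directed (head, relation, tail) triples; `shortest_relation_paths` and its `max_hops` parameter are hypothetical, so check build_align_qa_dataset.py for the logic the repository actually uses.

```python
from collections import defaultdict


def shortest_relation_paths(triples, source, target, max_hops=4):
    """Return the relation sequences of all shortest paths source -> target.

    Layered BFS: every partial path reaching a node at its minimal depth is
    kept, so ties between equally short paths are all returned.
    """
    adj = defaultdict(list)
    for head, rel, tail in triples:
        adj[head].append((rel, tail))

    frontier = [(source, [])]
    visited = {source}
    for _ in range(max_hops):
        next_frontier = []
        reached = set()
        for node, rels in frontier:
            for rel, tail in adj[node]:
                if tail in visited:  # already reached at a shallower depth
                    continue
                next_frontier.append((tail, rels + [rel]))
                reached.add(tail)
        found = [tuple(rels) for node, rels in next_frontier if node == target]
        if found:
            return sorted(set(found))
        # Mark nodes only after the whole layer, so same-depth ties survive.
        visited |= reached
        frontier = next_frontier
    return []
```

Returning all ties rather than a single path matters here because several distinct relation paths can connect the same question and answer, and each can serve as a valid plan.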

Training RoG

Training RoG requires two A100-80GB GPUs.

Run: ./scripts/train.sh

Results

Bibinfo

If you find this repo helpful, please cite our paper:

@inproceedings{luo2024rog,
title={Reasoning on Graphs: Faithful and Interpretable Large Language Model Reasoning},
author={Luo, Linhao and Li, Yuan-Fang and Haffari, Gholamreza and Pan, Shirui},
booktitle={International Conference on Learning Representations},
year={2024}
}

reasoning-on-graphs's People

Contributors: rmanbio, rmanluo

reasoning-on-graphs's Issues

Where are the Planning (e.g., KL-divergence) and Retrieval-reason optimization implemented?

Thanks for sharing this work!

The paper mentions using KL-divergence as a loss function for planning optimization. However, I couldn't locate the code that implements this KL-divergence loss, nor the retrieval-reasoning loss.

Could you or someone else please point me to the relevant files or provide more information on where these components are implemented?

Out-of-Memory Issue

Hi! I have a question about the current implementation: it does not use DataParallel, so it is unclear how to use multiple GPUs effectively. In your paper, you mention using two A100 GPUs, whereas we intend to use six A6000 GPUs. However, when we run the code as is, it uses only two GPUs and runs out of memory.

Could you kindly provide guidance on how to configure the code to specify the number of GPUs to be used?

Transfer to other KG

Hi there, thanks for your work! I would like to know how to transfer your model to a knowledge graph in my own field and build a KG-based LLM question-answering system.

Encountered some bugs when loading the dataset

Hi there,
I am a graduate student interested in your work!
I am trying to run your code. During planning inference (generating reasoning paths), gen_rule_path.py raises an error at line 122. Details are below:

Save results to: results/gen_rule_path/RoG-cwq/RoG/test
Using custom data configuration rmanluo--RoG-cwq-a052b4ae8515a88d
Downloading and preparing dataset parquet/rmanluo--RoG-cwq to /home/v-sitaocheng/.cache/huggingface/datasets/rmanluo___parquet/rmanluo--RoG-cwq-a052b4ae8515a88d/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8...
Traceback (most recent call last):
  File "/reasoning-on-graphs/src/qa_prediction/gen_rule_path.py", line 235, in <module>
    gen_path = gen_prediction(args)
  File "/reasoning-on-graphs/src/qa_prediction/gen_rule_path.py", line 122, in gen_prediction
    dataset = load_dataset(input_file, split=args.split)
  File "/anaconda/envs/LLM-kbqa/lib/python3.8/site-packages/datasets/load.py", line 1679, in load_dataset
    builder_instance.download_and_prepare(
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/builder.py", line 704, in download_and_prepare
    self._download_and_prepare(
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/builder.py", line 793, in _download_and_prepare
    self._prepare_split(split_generator, **prepare_split_kwargs)
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/builder.py", line 1271, in _prepare_split
    writer.write_table(table)
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/arrow_writer.py", line 518, in write_table
    self._build_writer(inferred_schema=pa_table.schema)
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/arrow_writer.py", line 352, in _build_writer
    inferred_features = Features.from_arrow_schema(inferred_schema)
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/features/features.py", line 1533, in from_arrow_schema
    return Features.from_dict(metadata["info"]["features"])
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/features/features.py", line 1562, in from_dict
    obj = generate_from_dict(dic)
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/features/features.py", line 1263, in generate_from_dict
    return {key: generate_from_dict(value) for key, value in obj.items()}
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/features/features.py", line 1263, in <dictcomp>
    return {key: generate_from_dict(value) for key, value in obj.items()}
  File "/anaconda/envs/ccc/lib/python3.8/site-packages/datasets/features/features.py", line 1267, in generate_from_dict
    return Sequence(feature=generate_from_dict(obj["feature"]), length=obj["length"])
KeyError: 'length'

I searched online and found that the problem might be with the dataset. Should I download it manually, or is something else wrong?

Training code and configuration requirements

Hi, bro,
could you upload the code for training the model?
Also, please add some information about the hardware requirements: what kind of machine and GPU are needed, and how much VRAM?

Where should I put the datasets?

Hi, I read your paper and it was really amazing. But I have some questions.

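I downloaded the two datasets, "RoG-webqsp" and "RoG-cwq", and also downloaded the RoG model weights.

I put the data files in a folder named rmanluo under the project root directory.

But when I run rog-reasoning.sh, I still get an error saying that the data files cannot be found.

[Screenshot 2024-03-11 15:18:07]

Here is the error message: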

(reasoningongraph) user@user-4U-GPU-Server:/data/czy/projects/reasoning-on-graphs$ ./scripts/rog-reasoning.sh
########
rmanluo/RoG-webqsp
########
Traceback (most recent call last):
  File "/data/czy/projects/reasoning-on-graphs/src/qa_prediction/predict_answer.py", line 236, in <module>
    main(args, LLM)
  File "/data/czy/projects/reasoning-on-graphs/src/qa_prediction/predict_answer.py", line 99, in main
    dataset = load_dataset(input_file, split=args.split)
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 2112, in load_dataset
    builder_instance = load_dataset_builder(
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 1798, in load_dataset_builder
    dataset_module = dataset_module_factory(
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 1429, in dataset_module_factory
    ).get_module()
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 861, in get_module
    module_name, default_builder_kwargs = infer_module_for_data_files(
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 516, in infer_module_for_data_files
    raise FileNotFoundError(f"No (supported) data files or dataset script found{path}")
FileNotFoundError: No (supported) data files or dataset script found in rmanluo/RoG-webqsp. 
########
rmanluo/RoG-cwq
########
Traceback (most recent call last):
  File "/data/czy/projects/reasoning-on-graphs/src/qa_prediction/predict_answer.py", line 236, in <module>
    main(args, LLM)
  File "/data/czy/projects/reasoning-on-graphs/src/qa_prediction/predict_answer.py", line 99, in main
    dataset = load_dataset(input_file, split=args.split)
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 2112, in load_dataset
    builder_instance = load_dataset_builder(
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 1798, in load_dataset_builder
    dataset_module = dataset_module_factory(
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 1429, in dataset_module_factory
    ).get_module()
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 861, in get_module
    module_name, default_builder_kwargs = infer_module_for_data_files(
  File "/data/miniconda3/envs/reasoningongraph/lib/python3.10/site-packages/datasets/load.py", line 516, in infer_module_for_data_files
    raise FileNotFoundError(f"No (supported) data files or dataset script found{path}")
FileNotFoundError: No (supported) data files or dataset script found in rmanluo/RoG-cwq. 

Could you tell me how to fix this?

Where is the code related to Optimization?

Hi there, thanks for your work. But I have a small question. In your paper, you mention two kinds of optimization (planning optimization and retrieval-reasoning optimization) in the process of generating relation paths, but I could not find the related code in the project. I wonder how the LLM is fine-tuned so that it can generate reasonable relation paths. This seems challenging, since the LLM may not have seen many such tasks during training. Looking forward to your reply!

Data path while running the "generate_explanation_results.py" file.

Hello, thanks for sharing your work.

I have a question about the process of building interpretable examples. I'm currently preprocessing the dataset and encountered the following error:

(RoG) youminkk@gold:~/Paper/RoG$ python src/joint_training/preprocess_qa.py
[2024-07-16 14:53:18,110] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/home/youminkk/miniconda3/envs/RoG/lib/python3.10/site-packages/transformers/utils/generic.py:260: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
/home/youminkk/miniconda3/envs/RoG/lib/python3.10/site-packages/transformers/utils/generic.py:260: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
/home/youminkk/.local/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Creating json from Arrow format: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 27.54ba/s]
Resolving data files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 78.69it/s]
Resolving data files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 106936.93it/s]
Creating json from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 28.45ba/s]
(RoG) youminkk@gold:~/Paper/RoG$ python src/joint_training/generate_explanation_results.py
[2024-07-16 14:53:39,471] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
/home/youminkk/miniconda3/envs/RoG/lib/python3.10/site-packages/transformers/utils/generic.py:260: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
/home/youminkk/miniconda3/envs/RoG/lib/python3.10/site-packages/transformers/utils/generic.py:260: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
/home/youminkk/.local/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Traceback (most recent call last):
  File "/home/youminkk/Paper/RoG/src/joint_training/generate_explanation_results.py", line 132, in <module>
    train_dataset = datasets.load_dataset(input_file, split="train")
  File "/home/youminkk/miniconda3/envs/RoG/lib/python3.10/site-packages/datasets/load.py", line 2594, in load_dataset
    builder_instance = load_dataset_builder(
  File "/home/youminkk/miniconda3/envs/RoG/lib/python3.10/site-packages/datasets/load.py", line 2266, in load_dataset_builder
    dataset_module = dataset_module_factory(
  File "/home/youminkk/miniconda3/envs/RoG/lib/python3.10/site-packages/datasets/load.py", line 1916, in dataset_module_factory
    raise FileNotFoundError(
FileNotFoundError: Couldn't find a dataset script at /home/youminkk/Paper/RoG/datasets/joint_training/qa/webqsp/webqsp.py or any data file in the same directory.

I checked the code in generate_explanation_results.py and found the section where the data path is specified.

## Line 15 - 19

save_dir = "datasets/joint_training/ExplainQAData"
split="train"
model_max_length = 1024
data_list = ['webqsp', 'cwq']
data_path = "/home/lluo/projects/KIT/data/KGQA"

How should I modify this part? For reference, the preprocessing steps were all completed successfully. Could you kindly let me know the correct path I should modify?

Have the parameters of the LLM's input embeddings been tuned?

Thank you for providing the code. The paper mentions introducing new tokens, marked as <\path>. I have a question about whether the LLM's input embeddings are tuned. In the training code, I noticed that requires_grad is not explicitly set to True for get_input_embeddings().parameters(). Could you please clarify whether this tuning is necessary?

Knowledge graph

Hi there, thanks for your work.
But I have a small question. In your paper, you mention two kinds of optimization (planning optimization and retrieval-reasoning optimization) in the process of generating relation paths.

In planning optimization, you aim to distill knowledge from KGs into LLMs so that they generate relation paths as plans.

Does generating relation paths by prompting the LLM, instead of actually querying the knowledge graph, carry the same meaning as using the knowledge graph?
However, in the code, 'subgraph' seems to be used in conjunction with 'prompt'.

Thank you.

Hits@1 Evaluation

Hi, great work!

I have a question about the Hits@1 evaluation. In file evaluation results.py, Hits is evaluated as follows:

def eval_hit(prediction, answer):
    for a in answer:
        if match(prediction, a):
            return 1
    return 0
hit = eval_hit(prediction_str, answer)

The result on WebQSP is 86.36% (Hits).

However, this snippet compares every answer with all $k$ generated predictions (Hits@k), and not the top-1 prediction. If I use the top-1 prediction only (similar to competing methods):

def eval_hit1(prediction, answer):
    for a in answer:
        if match(prediction[0], a):
            return 1
    return 0
hit = eval_hit1(prediction, answer)

On WebQSP, the result drops to 80.34% (Hits@1).

Is my understanding about the evaluation correct or am I missing something? Thank you!
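To make the distinction raised in this issue concrete, here is a small self-contained sketch. The `normalize` and `match` functions are assumptions (a normalized substring test), not necessarily what the repository's evaluation script does; `eval_hit` checks all k predictions by joining them into one string, while `eval_hit1` checks only the top-ranked one.

```python
import string


def normalize(text):
    """Assumed normalization: lower-case, drop punctuation, collapse spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return " ".join(text.split())


def match(prediction, answer):
    """Hypothetical matcher: the answer counts if it occurs as a substring."""
    return normalize(answer) in normalize(prediction)


def eval_hit(predictions, answers):
    """Hits over all k predictions, joined into one string."""
    prediction_str = " ".join(predictions)
    return int(any(match(prediction_str, a) for a in answers))


def eval_hit1(predictions, answers):
    """Hits@1: only the top-ranked prediction is checked."""
    if not predictions:
        return 0
    return int(any(match(predictions[0], a) for a in answers))
```

With predictions ["Presidential system", "Parliamentary system"] and the answer "Parliamentary system", `eval_hit` returns 1 but `eval_hit1` returns 0, which is the kind of gap between the two reported numbers. Note that substring matching on a joined string can also succeed on text that straddles two predictions.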

Some questions about settings of experiments

Hi there,
I am trying to reproduce your results. Here are some questions I am curious about:

In Table 2, on the CWQ dataset, RoG achieves 62.6 Hits@1 and 56.2 F1, which is great.
But in Table 4, 'LLaMA2-Chat-7B + RoG Planning' achieves 56.41 Hits@1 on CWQ (even better than ChatGPT?). Did you fine-tune this model in the reasoning setting? If so, what is the difference between this setting and the original RoG, and why are the results different (62.6 vs. 56.41 Hits@1)?

Thank you for your precious time!
