Comments (11)

shichence commented on July 30, 2024 3

Hi, I found a workaround for the problem. In predict_synthon (torchdrug/tasks/retrosynthesis.py), you just need to replace the following line

target, size = self.target(batch)

with

reactant, product = batch["graph"]
graph = product.directed()
size = graph.num_edges + graph.num_nodes
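
The replacement drops target but still computes size, because predict_synthon passes size to functional.variadic_log_softmax(pred, size) to normalize the scores of each product separately. As a rough illustration of what such a per-graph log-softmax computes, here is a pure-Python sketch. The function name segment_log_softmax is hypothetical and this is not TorchDrug's implementation:

```python
import math

def segment_log_softmax(scores, sizes):
    """Apply log-softmax independently to consecutive segments of `scores`.

    `sizes[i]` is the number of candidates (edges + nodes) of the i-th graph.
    Hypothetical helper for illustration only.
    """
    if sum(sizes) != len(scores):
        raise ValueError("sizes must sum to len(scores)")
    out = []
    start = 0
    for size in sizes:
        segment = scores[start:start + size]
        if segment:
            # Subtract the maximum for numerical stability
            peak = max(segment)
            log_norm = peak + math.log(sum(math.exp(s - peak) for s in segment))
            out.extend(s - log_norm for s in segment)
        start += size
    return out

# Two packed graphs: 3 candidates for the first, 2 for the second
logp = segment_log_softmax([0.0, 0.0, 0.0, 1.0, 1.0], [3, 2])
```

Within each segment the exponentiated values sum to 1, which is why size is needed even when no ground truth target is available.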

Then, the following code snippet will predict the reactants for customized targets.

import torch
from torchdrug import core, data, datasets, models, tasks, utils

# Load USPTO50k twice: once for center identification, once for synthon completion
reaction_dataset = datasets.USPTO50k("molecule-datasets",
                                     node_feature="center_identification",
                                     kekulize=True)
synthon_dataset = datasets.USPTO50k("molecule-datasets", as_synthon=True,
                                    node_feature="synthon_completion",
                                    kekulize=True)

# Use the same random seed before each split
torch.manual_seed(1)
reaction_train, reaction_valid, reaction_test = reaction_dataset.split()
torch.manual_seed(1)
synthon_train, synthon_valid, synthon_test = synthon_dataset.split()

reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
                    hidden_dims=[256, 256, 256, 256, 256, 256],
                    num_relation=reaction_dataset.num_bond_type,
                    concat_hidden=True)
reaction_task = tasks.CenterIdentification(reaction_model,
                                           feature=("graph", "atom", "bond"))
synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,
                            hidden_dims=[256, 256, 256, 256, 256, 256],
                            num_relation=synthon_dataset.num_bond_type,
                            concat_hidden=True)
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))

reaction_task.preprocess(reaction_train, None, None)
synthon_task.preprocess(synthon_train, None, None)
task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=2,
                            num_synthon_beam=5, max_prediction=10)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, reaction_train, None, None,
                     optimizer, gpus=[0], batch_size=32)
solver.load("g2gs_reaction_model.pth", load_optimizer=False)
solver.load("g2gs_synthon_model.pth", load_optimizer=False)

smiles_list = ["CCOC(=O)C"]
molecules = data.PackedMolecule.from_smiles(smiles_list, node_feature="center_identification",
                                            kekulize=True)
dummy_reaction_type = torch.zeros(len(smiles_list), dtype=torch.long)
batch = {"graph": (molecules, molecules), "reaction": dummy_reaction_type}
batch = utils.cuda(batch)
predictions, num_prediction = task.predict(batch)
for i in range(len(predictions)):
    print(predictions[i].to_smiles())

The output is

CCOC(C)=O
CC.CC(=O)O
CC(=O)O.CCB(O)O
CC(=O)O.CCI
CC.O=C1CO1
CC(=O)OCCBr
CC.OC1CO1
O=C1C=CCO1
CC(=O)O.CCB1OB(C)OB(C)O1
CCB(O)O.O=C1CO1

We will update the code once we have tested it carefully.

from torchdrug.

KiddoZhu commented on July 30, 2024 2

Thanks for your information. I think there are some flaws when tasks.Retrosynthesis.forward operates on samples without ground truth information. It assumes ground truth is always present in the batch for computing the metrics. @shichence Could you please help fix that?
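
One possible shape for such a fix is to check for ground truth before computing targets or metrics. The sketch below assumes that ground truth appears as edge_label and node_label attributes on the product graph, which is what the traceback in this thread suggests. has_center_labels is a hypothetical helper, and the objects are plain stand-ins rather than TorchDrug graphs:

```python
from types import SimpleNamespace

def has_center_labels(graph):
    """Return True if the graph carries reaction-center ground truth.

    Assumption: ground truth is stored as `edge_label` and `node_label`.
    """
    return hasattr(graph, "edge_label") and hasattr(graph, "node_label")

# Stand-ins for a labeled training sample and an unlabeled inference sample
labeled = SimpleNamespace(edge_label=[0, 1], node_label=[0, 0])
unlabeled = SimpleNamespace()

print(has_center_labels(labeled), has_center_labels(unlabeled))
```

A task could use a check like this to skip target and metric computation for samples that only contain the product.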

shichence commented on July 30, 2024 1

I checked the code of task.predict and I believe the outputs are ranked by their logp in descending order.
Note that each graph has a logps attribute (graph.logps), so you can also rank these predictions yourself.
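
A minimal sketch of ranking candidates yourself, assuming you have already extracted one log-likelihood per prediction into a plain list (for a tensor, something like .tolist() would be needed first). rank_by_logp is a hypothetical helper, and the labels and scores below are placeholders, not model output:

```python
def rank_by_logp(candidates, logps):
    """Pair each candidate with its log-likelihood, highest first."""
    if len(candidates) != len(logps):
        raise ValueError("candidates and logps must have the same length")
    order = sorted(range(len(logps)), key=lambda i: logps[i], reverse=True)
    return [(candidates[i], logps[i]) for i in order]

# Placeholder labels and scores for illustration only
ranked = rank_by_logp(["A", "B", "C"], [-2.5, -0.7, -1.3])
for candidate, logp in ranked:
    print(candidate, logp)
```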

KiddoZhu commented on July 30, 2024

Adapted from the retrosynthesis tutorial.

You need to use a center identification task without the reaction feature; otherwise, you need to provide the reaction type during inference.

task = tasks.Retrosynthesis(reaction_task, synthon_task, ...)
smiles_list = ["....", "...."]
molecules = data.PackedMolecule.from_smiles(smiles_list, node_feature="center_identification",
                                            kekulize=True)
batch = {"graph": molecules}
batch = utils.cuda(batch)
predictions, num_prediction = task.predict(batch)

Let me know if you have questions about the code.

juliachen123 commented on July 30, 2024

@KiddoZhu Thanks for getting back to me!
I was getting an error when running inference with the product/target only; now I know it's the missing reaction type. :)
I would like to clarify with you what using a center identification task without the reaction feature means. Are you referring to something like this:

reactant, product = reaction_dataset[i]["graph"]

product_only_dataset[i]["graph"] = product

and then do the center identification task training?

and for the synthon task, do both reactant and synthon stay in the training data?

Thank you in advance!

KiddoZhu commented on July 30, 2024

You don't have to tweak the dataset. If you don't use reaction feature in center identification & synthon completion, then they will automatically ignore batch["reaction"] in training and inference. The two tasks defined in the tutorial already satisfy this situation.

For the synthon task, the dataset should be loaded with argument as_synthon=True, then each sample batch["graph"] is a pair of reactant and synthon. See the visualization in our tutorial for more details.

juliachen123 commented on July 30, 2024

Thanks for getting back to me. I followed your previous snippet for inference but got an error ValueError: not enough values to unpack (expected 2, got 1). It's still looking for reactant.

Maybe I misunderstood what you meant previously, so let me step back and clarify again, sorry.

I've followed the retrosynthesis tutorial and was able to reproduce the notebook. I've got both "g2gs_reaction_model.pth" and "g2gs_synthon_model.pth" ready. But it seems that for inference, we don't need to load the solver?
After that, here's the snippet I ran for inference with the product only:

reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
                    hidden_dims=[256, 256, 256, 256, 256, 256],
                    num_relation=reaction_dataset.num_bond_type,
                    concat_hidden=True)
reaction_task = tasks.CenterIdentification(reaction_model, feature=("graph", "atom", "bond"))

synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,
                            hidden_dims=[256, 256, 256, 256, 256, 256],
                            num_relation=synthon_dataset.num_bond_type,
                            concat_hidden=True)
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))

reaction_task.preprocess(reaction_train, None, None)
synthon_task.preprocess(synthon_train, None, None)
task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=2,
                            num_synthon_beam=5, max_prediction=10)

smiles_list = ["C1=CC=CC=C1"]
molecules = data.PackedMolecule.from_smiles(smiles_list, node_feature="center_identification",
                                            kekulize=True)
batch = {"graph": molecules}
batch = utils.cuda(batch)
task.predict(batch)

I got the error ValueError: not enough values to unpack (expected 2, got 1) on task.predict(batch)
My understanding is that I don't need to train the task with any modification of the dataset, right?

Thank you in advance!

KiddoZhu commented on July 30, 2024

Oh yes. I checked the code and it requires batch["graph"] to be a tuple of reactants and products.

I feel the easiest workaround is to use batch = {"graph": (molecules, molecules)}, as the ground truth reactants have no effect during inference.
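
A small sketch of that batch layout, written as a helper so the optional reaction entry is explicit. make_inference_batch is a hypothetical name, and the helper only builds the dictionary. As the traceback in this thread shows, predict_synthon may still need the patch that avoids self.target(batch) before such a batch works:

```python
def make_inference_batch(products, reaction=None):
    """Build a batch whose "graph" entry is a pair, reusing the products
    as a stand-in for the (unknown) ground truth reactants."""
    batch = {"graph": (products, products)}
    if reaction is not None:
        # Only needed when the tasks use the reaction feature
        batch["reaction"] = reaction
    return batch

# Any object works here; in practice this would be a PackedMolecule
placeholder = object()
batch = make_inference_batch(placeholder)
```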

juliachen123 commented on July 30, 2024

So I used a tuple as suggested:

smiles_list = ["C1=CC=CC=C1"]
molecules = data.PackedMolecule.from_smiles(smiles_list, node_feature="center_identification",
                                            kekulize=True)
batch = {"graph": (molecules, molecules)}

Could the molecules in the tuple be the same compounds?

With this batch, I got AttributeError: 'PackedMolecule' object has no attribute 'edge_label' on task.predict(batch):

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_8415/3190283348.py in <module>
      6 batch = {"graph": (molecules, molecules)}
      7 batch = utils.cuda(batch)
----> 8 task.predict(batch)

/anaconda/envs/td/lib/python3.7/site-packages/torchdrug-0.1.1-py3.7.egg/torchdrug/tasks/retrosynthesis.py in predict(self, batch, all_loss, metric)
   1091 
   1092     def predict(self, batch, all_loss=None, metric=None):
-> 1093         synthon_batch = self.center_identification.predict_synthon(batch, self.center_topk)
   1094 
   1095         synthon = synthon_batch["synthon"]

/anaconda/envs/td/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

/anaconda/envs/td/lib/python3.7/site-packages/torchdrug-0.1.1-py3.7.egg/torchdrug/tasks/retrosynthesis.py in predict_synthon(self, batch, k)
    180         """
    181         pred = self.predict(batch)
--> 182         target, size = self.target(batch)
    183         logp = functional.variadic_log_softmax(pred, size)
    184 

/anaconda/envs/td/lib/python3.7/site-packages/torchdrug-0.1.1-py3.7.egg/torchdrug/tasks/retrosynthesis.py in target(self, batch)
    111         graph = product.directed()
    112 
--> 113         target = self._collate(graph.edge_label, graph.node_label, graph)
    114         size = graph.num_edges + graph.num_nodes
    115         return target, size

AttributeError: 'PackedMolecule' object has no attribute 'edge_label'

Should I use arbitrary reactants for batch = {"graph": (molecules, reactants)}?

Thanks!

juliachen123 commented on July 30, 2024

Thank you @shichence!
Are the outputs ranked by a certain metric? Is the metric saved in predictions?

juliachen123 commented on July 30, 2024

Thanks for getting back to me. This is really helpful!
