Comments (11)
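Hi, I found a workaround for the problem. You just need to replace the following line, which is the target, size = self.target(batch) call flagged at line 182 in the traceback in this thread,
torchdrug/torchdrug/tasks/retrosynthesis.py
Line 182 in 7fed9a5
with: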
reactant, product = batch["graph"]
graph = product.directed()
size = graph.num_edges + graph.num_nodes
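For context on why `size` still has to be computed: in the traceback in this thread, `predict_synthon` passes it to `functional.variadic_log_softmax(pred, size)`, so each product graph is normalised over its own edges plus nodes. The pure-Python sketch below only illustrates that idea. `segment_log_softmax` is a hypothetical name, not the TorchDrug implementation, and it assumes the scores are stored flat, one graph after another.

```python
import math

def segment_log_softmax(scores, sizes):
    """Apply log-softmax independently to consecutive segments of `scores`."""
    if sum(sizes) != len(scores):
        raise ValueError("sizes must sum to len(scores)")
    out = []
    start = 0
    for size in sizes:
        segment = scores[start:start + size]
        if segment:
            # Subtract the maximum for numerical stability.
            peak = max(segment)
            log_norm = peak + math.log(sum(math.exp(s - peak) for s in segment))
            out.extend(s - log_norm for s in segment)
        start += size
    return out

# Two graphs: the first has 3 candidate centers, the second has 2.
logp = segment_log_softmax([0.0, 0.0, 0.0, 1.0, 1.0], sizes=[3, 2])
```

Within each segment the exponentiated values sum to one, which is why the per-graph sizes are needed even when no ground truth target is available.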
Then, the following code snippet will predict the reactants for customized targets.
import torch
from torchdrug import core, data, datasets, models, tasks, utils

# Load the reaction and synthon views of USPTO50k, as in the tutorial.
reaction_dataset = datasets.USPTO50k("molecule-datasets",
                                     node_feature="center_identification",
                                     kekulize=True)
synthon_dataset = datasets.USPTO50k("molecule-datasets", as_synthon=True,
                                    node_feature="synthon_completion",
                                    kekulize=True)

# Reset the seed before each split so both datasets are split the same way.
torch.manual_seed(1)
reaction_train, reaction_valid, reaction_test = reaction_dataset.split()
torch.manual_seed(1)
synthon_train, synthon_valid, synthon_test = synthon_dataset.split()

reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
                             hidden_dims=[256, 256, 256, 256, 256, 256],
                             num_relation=reaction_dataset.num_bond_type,
                             concat_hidden=True)
reaction_task = tasks.CenterIdentification(reaction_model,
                                           feature=("graph", "atom", "bond"))

synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,
                            hidden_dims=[256, 256, 256, 256, 256, 256],
                            num_relation=synthon_dataset.num_bond_type,
                            concat_hidden=True)
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))

reaction_task.preprocess(reaction_train, None, None)
synthon_task.preprocess(synthon_train, None, None)
task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=2,
                            num_synthon_beam=5, max_prediction=10)

# Build a solver only to load the two trained checkpoints.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, reaction_train, None, None,
                     optimizer, gpus=[0], batch_size=32)
solver.load("g2gs_reaction_model.pth", load_optimizer=False)
solver.load("g2gs_synthon_model.pth", load_optimizer=False)

# Customized targets: the product is passed twice because the task expects
# a (reactant, product) pair, and the reaction type is a placeholder.
smiles_list = ["CCOC(=O)C"]
molecules = data.PackedMolecule.from_smiles(smiles_list, node_feature="center_identification",
                                            kekulize=True)
dummy_reaction_type = torch.zeros(len(smiles_list), dtype=torch.long)
batch = {"graph": (molecules, molecules), "reaction": dummy_reaction_type}
batch = utils.cuda(batch)
predictions, num_prediction = task.predict(batch)

for i in range(len(predictions)):
    print(predictions[i].to_smiles())
The output is
CCOC(C)=O
CC.CC(=O)O
CC(=O)O.CCB(O)O
CC(=O)O.CCI
CC.O=C1CO1
CC(=O)OCCBr
CC.OC1CO1
O=C1C=CCO1
CC(=O)O.CCB1OB(C)OB(C)O1
CCB(O)O.O=C1CO1
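In SMILES, a `.` separates disconnected molecules, so each output line above is one predicted reactant set. Here is a small sketch for splitting them apart using plain string handling only. `split_reactants` is an illustrative helper, not a TorchDrug function.

```python
def split_reactants(prediction):
    """Split one predicted SMILES string into its individual molecules."""
    return [part for part in prediction.split(".") if part]

# A few of the predictions listed above.
predicted = ["CC.CC(=O)O", "CC(=O)O.CCB(O)O", "CC(=O)OCCBr"]
reactant_sets = [split_reactants(p) for p in predicted]
for reactants in reactant_sets:
    print(reactants)
```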
We will update the code once we have tested it carefully.
from torchdrug.
Thanks for your information. I think there are some flaws when tasks.Retrosynthesis.forward operates on samples without ground truth information. It assumes ground truth is always present in the batch for computing the metrics. @shichence Could you please help fix that?
from torchdrug.
I checked the code of task.predict, and I believe the outputs are ranked by their logp in descending order.
Note that each graph has a logps attribute (graph.logps), so you can also rank the predictions yourself.
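A minimal sketch of re-ranking by hand, assuming the SMILES strings and log-probabilities have already been pulled out into plain Python lists (for instance from `to_smiles()` and the `logps` attribute mentioned above; the exact shape of `logps` is not confirmed here). `rank_by_logp` is a hypothetical helper, and the scores in the example are made up for illustration.

```python
def rank_by_logp(smiles, logps):
    """Return (smiles, logp) pairs sorted from most to least likely."""
    if len(smiles) != len(logps):
        raise ValueError("smiles and logps must have the same length")
    return sorted(zip(smiles, logps), key=lambda pair: pair[1], reverse=True)

# Placeholder scores, not real model output.
ranked = rank_by_logp(["CC.CC(=O)O", "CC(=O)O.CCI", "CC(=O)OCCBr"],
                      [-2.3, -0.7, -1.5])
for smi, logp in ranked:
    print(f"{logp:.2f}  {smi}")
```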
from torchdrug.
Adapted from the retrosynthesis tutorial.
You need to use a center identification task without the reaction feature, otherwise you need to provide the reaction type during inference.
task = tasks.Retrosynthesis(reaction_task, synthon_task, ...)
smiles_list = ["....", "...."]
molecules = data.PackedMolecule.from_smiles(smiles_list, node_feature="center_identification",
                                            kekulize=True)
batch = {"graph": molecules}
batch = utils.cuda(batch)
predictions, num_prediction = task.predict(batch)
Let me know if you have questions about the code.
from torchdrug.
@KiddoZhu Thanks for getting back to me!
I was getting an error when running inference with the product/target only; now I know it was the missing reaction type. :)
I would like to clarify what you mean by using a center identification task without the reaction feature. Are you referring to something like this:
reactant, product = reaction_dataset[i]["graph"]
product_only_dataset[i]["graph"] = product
and then training the center identification task on that?
And for the synthon task, do both reactant and synthon stay in the training data?
Thank you in advance!
from torchdrug.
You don't have to tweak the dataset. If you don't use the reaction feature in center identification and synthon completion, they will automatically ignore batch["reaction"] in training and inference. The two tasks defined in the tutorial already satisfy this condition.
For the synthon task, the dataset should be loaded with the argument as_synthon=True; each sample batch["graph"] is then a pair of reactant and synthon. See the visualization in our tutorial for more details.
from torchdrug.
Thanks for getting back to me. I followed your previous snippet for inference but got the error ValueError: not enough values to unpack (expected 2, got 1). It's still looking for the reactant.
Maybe I misunderstood what you meant previously, so let me step back and clarify, sorry.
I've followed the retrosynthesis tutorial and was able to reproduce the notebook. I have both "g2gs_reaction_model.pth" and "g2gs_synthon_model.pth" ready. But it seems that for inference we don't need to load the solver?
After that, here's the snippet I ran for inference with product only:
reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
                             hidden_dims=[256, 256, 256, 256, 256, 256],
                             num_relation=reaction_dataset.num_bond_type,
                             concat_hidden=True)
reaction_task = tasks.CenterIdentification(reaction_model, feature=("graph", "atom", "bond"))

synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,
                            hidden_dims=[256, 256, 256, 256, 256, 256],
                            num_relation=synthon_dataset.num_bond_type,
                            concat_hidden=True)
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))

reaction_task.preprocess(reaction_train, None, None)
synthon_task.preprocess(synthon_train, None, None)
task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=2,
                            num_synthon_beam=5, max_prediction=10)

smiles_list = ["C1=CC=CC=C1"]
molecules = data.PackedMolecule.from_smiles(smiles_list, node_feature="center_identification",
                                            kekulize=True)
batch = {"graph": molecules}
batch = utils.cuda(batch)
task.predict(batch)
I got the error ValueError: not enough values to unpack (expected 2, got 1) on task.predict(batch).
My understanding is that I don't need to train the task with any modification of the dataset, right?
Thank you in advance!
from torchdrug.
Oh yes. I checked the code and it requires batch["graph"] to be a tuple of reactants and products.
I feel the easiest workaround is to use batch = {"graph": (molecules, molecules)}, as the ground truth reactants have no effect during inference.
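A tiny sketch of that workaround as a helper, so a product-only input is always turned into the pair that the task unpacks. `as_pair_batch` is a hypothetical name; it only touches the dictionary and makes no TorchDrug calls. As the AttributeError reported elsewhere in this thread shows, the tuple alone may still fail inside `target()` unless `predict_synthon` is patched as well.

```python
def as_pair_batch(batch):
    """Return a copy of `batch` whose "graph" entry is a (reactant, product) pair."""
    graph = batch["graph"]
    if isinstance(graph, (tuple, list)):
        if len(graph) != 2:
            raise ValueError("batch['graph'] must hold exactly two graphs")
        pair = tuple(graph)
    else:
        # Product-only input: reuse the product as a placeholder reactant.
        pair = (graph, graph)
    fixed = dict(batch)
    fixed["graph"] = pair
    return fixed

# Strings stand in for PackedMolecule objects here.
fixed = as_pair_batch({"graph": "product"})
```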
from torchdrug.
So I used a tuple as suggested:
smiles_list = ["C1=CC=CC=C1"]
molecules = data.PackedMolecule.from_smiles(smiles_list, node_feature="center_identification",
                                            kekulize=True)
batch = {"graph": (molecules, molecules)}
Could the molecules in the tuple be the same compounds? With this batch I got AttributeError: 'PackedMolecule' object has no attribute 'edge_label' on task.predict(batch):
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/tmp/ipykernel_8415/3190283348.py in <module>
6 batch = {"graph": (molecules, molecules)}
7 batch = utils.cuda(batch)
----> 8 task.predict(batch)
/anaconda/envs/td/lib/python3.7/site-packages/torchdrug-0.1.1-py3.7.egg/torchdrug/tasks/retrosynthesis.py in predict(self, batch, all_loss, metric)
1091
1092 def predict(self, batch, all_loss=None, metric=None):
-> 1093 synthon_batch = self.center_identification.predict_synthon(batch, self.center_topk)
1094
1095 synthon = synthon_batch["synthon"]
/anaconda/envs/td/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30
/anaconda/envs/td/lib/python3.7/site-packages/torchdrug-0.1.1-py3.7.egg/torchdrug/tasks/retrosynthesis.py in predict_synthon(self, batch, k)
180 """
181 pred = self.predict(batch)
--> 182 target, size = self.target(batch)
183 logp = functional.variadic_log_softmax(pred, size)
184
/anaconda/envs/td/lib/python3.7/site-packages/torchdrug-0.1.1-py3.7.egg/torchdrug/tasks/retrosynthesis.py in target(self, batch)
111 graph = product.directed()
112
--> 113 target = self._collate(graph.edge_label, graph.node_label, graph)
114 size = graph.num_edges + graph.num_nodes
115 return target, size
AttributeError: 'PackedMolecule' object has no attribute 'edge_label'
Should I use arbitrary reactants for batch = {"graph": (molecules, reactants)}?
Thanks!
from torchdrug.
Thank you @shichence!
Are the outputs ranked by a certain metric? Is the metric saved in predictions?
from torchdrug.
Thanks for getting back to me. This is really helpful!
from torchdrug.