theislab / chemcpa
Code for "Predicting Cellular Responses to Novel Drug Perturbations at a Single-Cell Resolution", NeurIPS 2022.
Home Page: https://arxiv.org/abs/2204.13545
License: MIT License
Code for "Predicting Cellular Responses to Novel Drug Perturbations at a Single-Cell Resolution", NeurIPS 2022.
Home Page: https://arxiv.org/abs/2204.13545
License: MIT License
seml.get_results('cpa_graphs_15', to_data_frame=True) returns a pd.DataFrame.
r2, computed on (all genes, DE genes, top-x DE genes): (real - control) vs. (predicted - control).
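A minimal sketch of this comparison (assuming numpy/sklearn and per-cell expression matrices; not the repo's implementation):

```python
import numpy as np
from sklearn.metrics import r2_score

def r2_over_control(real, pred, control, gene_idx=None):
    """real, pred, control: (cells, genes) arrays; gene_idx selects the
    gene set (all genes, DE genes, top-x DE genes)."""
    if gene_idx is not None:
        real, pred, control = (a[:, gene_idx] for a in (real, pred, control))
    delta_real = real.mean(axis=0) - control.mean(axis=0)
    delta_pred = pred.mean(axis=0) - control.mean(axis=0)
    return r2_score(delta_real, delta_pred)
```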
Ideally we can standardise the evaluation → Figure design
For this it would be great to get some scripts & examples!
Similar to the perturbation disentanglement, the calculation of the optimal value for covariates is probably wrong and should be updated.
The Vanilla model with ~6M parameters outperforms most of our other models; however, those models have at most ~2M parameters.
It might be worth scheduling some runs with very large models to see how they perform.
If we check the perturbation disentanglement to make sure we're not accidentally training an autoencoder, then we can run a regular test-set reconstruction loss.
Otherwise we will not be able to discern models that perform well averaged across cells (e.g. by predicting the mean) from models that perform well on each individual cell.
As a sanity check baseline we should add an embedding that's just a zero table and isn't updated during training.
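A minimal sketch of such a baseline in PyTorch (sizes are hypothetical):

```python
import torch
import torch.nn as nn

num_drugs, emb_dim = 188, 256  # hypothetical sizes
# All-zero embedding table, frozen so it is never updated during training;
# no drug information can enter the model through it.
zero_embedding = nn.Embedding.from_pretrained(
    torch.zeros(num_drugs, emb_dim), freeze=True
)
```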
Should go into the analyze_lincs_all_embeddings_hparam.ipynb. We use this to show that the counterfactual prediction is working correctly and that the information coming from the embeddings is actually useful.
Test how much pretraining on LINCS helps with improving OOD drug prediction on Trapnell.
Would allow accurate predictions of single-cell response to unseen drugs, without spending more money on the datasets.
The pretrained models perform better than the non-pretrained model.
The model that has seen the hold-out drugs on LINCS performs better than the pre-trained model that hasn't seen the drugs before.
A class that wraps the specific predictor (CPA, scVI, cVAE, ...) and exposes train and predict functions.
<TO BE EXTENDED>
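A possible shape for this wrapper (names and signatures are assumptions, not an existing API):

```python
from abc import ABC, abstractmethod

class PerturbationPredictor(ABC):
    """Wraps a specific predictor (CPA, scVI, cVAE, ...) behind a
    common train/predict interface."""

    @abstractmethod
    def train(self, adata) -> None:
        """Fit the wrapped model on an AnnData object."""

    @abstractmethod
    def predict(self, adata, drug: str, dose: float):
        """Return counterfactual gene expression for the given drug/dose."""
```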
I think SMILES should be canonical for this.
Observing how the transformed drug embeddings cluster, and comparing the clustering to Vanilla, to the clustering of Trapnell gene expressions (supp figure in Sciplex paper), to the clustering of untransformed drug embeddings, to the clustering of pretrained-then-finetuned models.
Which precise plots we'll include in the paper is unclear.
Overall the goal is to give credence to the claim that the chemical embeddings are meaningful and contribute to lowering the chemCPA loss.
Currently we're only planning to perform this experiment for Trapnell, as we don't have information about the drug pathways for LINCS.
Stop training when we get NaNs.
Looking at the results in chemical_CPA/simon/plot_sweep_results.ipynb
Hypothesis: The drug disentanglement might be easier for NN-based embeddings, as these are chemically motivated.
Potential experiment: Make the classifier stronger by including more layers than just the linear one used for the logistic regression.
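A sketch of what a stronger adversary could look like (assuming the current one is a single linear layer):

```python
import torch.nn as nn

def make_adversary(latent_dim: int, num_drugs: int, hidden: int = 256) -> nn.Module:
    # two hidden layers instead of plain logistic regression
    return nn.Sequential(
        nn.Linear(latent_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, num_drugs),  # logits over drugs
    )
```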
Currently we have a very simple Seq2Seq model, which nevertheless seems to be performing well. This model could be substantially improved by making it more similar to the VAE presented in https://pubs.acs.org/doi/abs/10.1021/acscentsci.7b00572
I'm not sure if we should invest the time into adding yet another embedding. However, this can probably be done within a day by using a pre-existing implementation.
Needs:
The implementation in moses should work for us. We need to be careful with the KL divergence: in their experiments they mostly care about generative (i.e. sampling) performance. For our case it may be useful to use smaller βs (so the KL divergence contributes less to the overall loss), which will reduce sampling performance but increase reconstruction performance (and hopefully lead to a more meaningful latent space).
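For reference, a generic β-weighted VAE objective (a sketch, not the moses code):

```python
import torch

def vae_loss(recon_loss: torch.Tensor, mu: torch.Tensor,
             logvar: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    # KL divergence of a diagonal Gaussian q(z|x) from the standard normal;
    # a smaller beta trades sampling quality for reconstruction quality.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return recon_loss + beta * kl.mean()
```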
The published repo does not contain sweeps for the LINCS dataset.
We should agree on hparams for both the Compert and DrugEmb classes.
Add MC Dropout for minimal uncertainty in CPA and GNNs
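A generic sketch of MC Dropout at inference time (model and inputs are placeholders):

```python
import torch
import torch.nn as nn

def mc_dropout_predict(model: nn.Module, x: torch.Tensor, n_samples: int = 20):
    model.eval()
    for m in model.modules():  # re-enable only the dropout layers
        if isinstance(m, nn.Dropout):
            m.train()
    with torch.no_grad():
        preds = torch.stack([model(x) for _ in range(n_samples)])
    return preds.mean(0), preds.std(0)  # prediction and a cheap uncertainty
```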
According to Mo, the subsampling in the Trapnell preprocessing (removing 50% of the rows) is mainly done for performance reasons. Our code should be fast enough, so we can remove it and get more data.
Class imbalance should be incorporated into the BCE loss. Example on Trapnell:
> adata.obs["condition"].value_counts()[:10]
control 6464
GSK-LSD1 1868
BRD4770 1868
Baricitinib 1862
Entacapone 1853
RG108 1852
WP1066 1851
Curcumin 1850
Capecitabine 1849
Mesna 1847
Name: condition, dtype: int64
> adata.obs["condition"].value_counts()[-10:]
Rigosertib 984
Luminespib 980
Tozasertib 975
Mocetinostat 949
Alvespimycin 930
AT9283 910
Patupilone 757
Flavopiridol 693
Epothilone 583
YM155 394
Name: condition, dtype: int64
This is not super important, as the dataset is not actually very imbalanced. But it will make the adversarial loss more meaningful.
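A sketch of how the counts above could be turned into class weights (assuming a cross-entropy-style adversary loss; the weight vector must be ordered consistently with the label encoding):

```python
import torch
import torch.nn as nn

counts = adata.obs["condition"].value_counts().sort_index()
weights = torch.tensor((counts.sum() / counts).values, dtype=torch.float32)
weights /= weights.mean()  # normalize so the loss scale stays comparable
adversary_loss = nn.CrossEntropyLoss(weight=weights)
```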
evaluate still takes ~2h on LINCS (1h for evaluate, 1h for evaluate logfold). This makes it impossible to run the evaluation frequently; instead we can only run it at the end of a full training run. There is no good reason why this should take 2h; it can probably be done in a few minutes.
How to deal with this:
Problems: It's hard to guess how much effort this will be, could be 4h, could be 3 days. For now we can already do some runs with just a single evaluation at the end, though this issue will have to be fixed to enable HParam tuning and creating proper loss plots.
I'm having problems with getting canonicalization to work.
print(Chem.CanonSmiles("N[C@]12C[C@H]3C[C@@H](C[C@H](C3)C1)C2.N[C@]12C[C@H]3C[C@@H](C[C@H](C3)C1)C2"))
print(Chem.CanonSmiles(Chem.CanonSmiles("N[C@]12C[C@H]3C[C@@H](C[C@H](C3)C1)C2.N[C@]12C[C@H]3C[C@@H](C[C@H](C3)C1)C2")))
print(Chem.CanonSmiles(Chem.CanonSmiles(Chem.CanonSmiles("N[C@]12C[C@H]3C[C@@H](C[C@H](C3)C1)C2.N[C@]12C[C@H]3C[C@@H](C[C@H](C3)C1)C2"))))
Results in:
N[C@]12C[C@@H]3C[C@H](C[C@H](C3)C1)C2.N[C@]12C[C@@H]3C[C@H](C[C@H](C3)C1)C2
N[C@]12C[C@H]3C[C@@H](C[C@H](C3)C1)C2.N[C@]12C[C@H]3C[C@@H](C[C@H](C3)C1)C2
N[C@]12C[C@@H]3C[C@H](C[C@H](C3)C1)C2.N[C@]12C[C@@H]3C[C@H](C[C@H](C3)C1)C2
As this is a loop, we can never get to a canonical representation. Similar things happen if I use rdkit.Chem.MolToSmiles(rdkit.Chem.MolFromSmiles(smiles), canonical=True).
What's the correct way to canonicalize a molecule? @MxMstrmn @M0hammadL
This leads to all kinds of indexing errors when trying to encode molecules via the GROVER embedding. I can make it work without canonicalizing, but it seems like this should be a solvable problem.
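One possible workaround (a sketch, not a proper fix): iterate canonicalization until it stabilizes, and break ties deterministically when it oscillates:

```python
from rdkit import Chem

def canon_fixed_point(smiles: str, max_iter: int = 10) -> str:
    seen = [smiles]
    for _ in range(max_iter):
        nxt = Chem.CanonSmiles(seen[-1])
        if nxt == seen[-1]:  # reached a fixed point
            return nxt
        if nxt in seen:  # a cycle like the one shown above
            return min(seen[seen.index(nxt):])  # deterministic tie-break
        seen.append(nxt)
    return seen[-1]
```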
Checklist:
embeddings/lincs_trapnell.smiles
compert/embedding.py: get_chemical_representation
./storage/groups/ml01/projects/2021_chemicalCPA_leon.hetzel/embeddings/
There is some unused code (e.g. much of train.py, api.py, ...) that should be deleted in a number of small commits (so we can revert them if the code turns out to still be needed).
Reminder for @MxMstrmn to check how the LINCS train / test split was calculated (column split1 in lincs_full_smiles.h5ad) and whether it should be adjusted.
In LINCS we have 979 genes, in Trapnell there are 5000. However the naive overlap (via plain text matching) is only 89. We have to define & implement a strategy for dealing with this during the transfer.
Test whether we can accurately predict the gene expressions for drugs that haven't been observed during training.
This has nothing to do with pretraining, but answers the question: Do our molecular embeddings allow us to accurately predict gene responses for unseen drugs?
I am having issues setting up the chemCPA environment. I cloned the repo onto my machine and executed the command 'conda env create -f environment.yml'. I get the following error:
Solving environment: failed
ResolvePackageNotFound:
So I removed moses and RDKit from the .yml file. I then ran 'conda env create -f environment.yml' again, intending to install MOSES / RDKit manually afterwards via these instructions found in the MOSES GitHub repo:
'The simplest way to install MOSES (models and metrics) is to install RDKit: conda install -yq -c rdkit rdkit and then install MOSES (molsets) from pip (pip install molsets). If you want to use LatentGAN, you should also install additional dependencies using bash install_latentgan_dependencies.sh.
If you are using Ubuntu, you should also install sudo apt-get install libxrender1 libxext6 for RDKit.'
However, it gets stuck on 'Solving Environment.' Please let me know how I should proceed!
2000 instead of the current 977 genes
Part of #76
Get the combination dataset from Alex for checking drug interactions (synergistic, etc.) and curate the data.
After loading a dataset, the current code generates an OHE for the drugs. This takes way too long on LINCS, since there are 17K drugs and the OHE generation runs in a single CPU thread.
This should be sped up, either by removing the OHE encoding and working with indices instead, or by somehow speeding up the OHE generation.
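A sketch of the index-based alternative (assuming drug labels live in adata.obs["condition"]):

```python
import pandas as pd
import torch

# Store integer codes instead of a (n_cells x 17K) one-hot matrix.
drug_idx = torch.as_tensor(
    pd.Categorical(adata.obs["condition"]).codes, dtype=torch.long
)
# Expand on demand, only where a one-hot is really needed:
# one_hot = torch.nn.functional.one_hot(drug_idx, num_classes=num_drugs)
```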
I'm not happy with the results of the large sweep that we ran on LINCS. Mainly:
I'd do it like this:
I think it's important to get this right before we design too many other experiments.
If it turns out that the transfer learning doesn't help for improving Trapnell scores, then there's no use in implementing #62 for example.
We can then still run the other experiments like #67 and hope to see improvements there.
If that doesn't work either, we can check whether the model is at least useful for predicting the effects of drugs that it hasn't seen.
That's just for covering our bases in the worst case; I think with some tweaking the transfer learning will work.
When training on LINCS (on our new random_split), after the epochs have finished we run a full evaluate. This runs into an error:
2021-12-28 14:25:33 (INFO): Running the full evaluation (Epoch:300)
Number of different r2 computations: 120
Number of different r2 computations: 34158
2021-12-28 14:33:35 (ERROR): Failed after 4:54:16!
Traceback (most recent call last):
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/experiment.py", line 312, in run_commandline
    return self.run(
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/experiment.py", line 276, in run
    run()
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/run.py", line 238, in __call__
    self.result = self.main_function(*args)
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/config/captured_function.py", line 42, in captured_function
    result = wrapped(*args, **kwargs)
  File "/tmp/de393a80-c152-4a11-8880-4b02eb8b1780/compert/seml_sweep_icb.py", line 321, in train
    return experiment.train()
  File "/home/icb/simon.boehm/miniconda3/envs/chemical_CPA/lib/python3.9/site-packages/sacred/config/captured_function.py", line 42, in captured_function
    result = wrapped(*args, **kwargs)
  File "/tmp/de393a80-c152-4a11-8880-4b02eb8b1780/compert/seml_sweep_icb.py", line 256, in train
    evaluation_stats = evaluate(
  File "/tmp/de393a80-c152-4a11-8880-4b02eb8b1780/compert/train.py", line 322, in evaluate
    "ood": evaluate_r2(
  File "/tmp/de393a80-c152-4a11-8880-4b02eb8b1780/compert/train.py", line 241, in evaluate_r2
    np.array(dataset.de_genes[cell_drug_dose_comb])
KeyError: 'A375_DMSO_0.1'
I briefly looked at it, and A375_DMSO_0.1 actually doesn't exist; neither does A375_control_0.1 or any variation. Don't quite know what the underlying cause is, but we should fix it.
It's not a super big issue, but it causes all seml runs to fail without having any results recorded inside the MongoDB.
We need to invest more effort to get the covariate embeddings to properly transfer from LINCS ⇒ SciPlex.
The solution is to save the string name of the covariate together with the embedding into the state_dict, so that we can keep the same mapping. The way we currently have it implemented, we just pull a basically random embedding from the top of the table, so this is not super useful.
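A sketch of what saving the mapping could look like (hypothetical helpers, not the current code):

```python
import torch

def save_covariate_embedding(embedding, names, path):
    # persist the name -> row mapping alongside the weights
    torch.save({"weight": embedding.weight.detach().cpu(), "names": names}, path)

def load_covariate_row(path, name):
    ckpt = torch.load(path)
    return ckpt["weight"][ckpt["names"].index(name)]  # look up by name, not position
```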
Needs to support:
Currently the CPA Adv predictor outputs a probability distribution over all possible drugs, given the basal state. This works well for datasets like Trapnell (188 drugs) but for LINCS (>17K drugs) it is not feasible because there aren't enough samples of each drug.
Therefore this needs to be adapted:
Different approaches are possible during finetuning:
The final CPA model is used to predict counterfactuals. For this to work all information about the drug has to have been removed from the latent basal state. This may not fully be the case when we predict just the cluster assignment. Example: If there are potent and less potent drugs in each cluster, then a notion of potency may remain in the latent basal state even though the cluster cannot be predicted anymore. If we use strategy 1) during finetuning we'll definitely ensure "latent basal state drug ambivalence" for the final model. An alternative approach may be to have the Adv Predictor predict the drug embedding directly (using a smaller, <1000 dim drug embedding) and using a cosine distance to the "true" drug embedding as the adversarial loss. This is how BERT models predict words. For now the cluster strategy seems more promising.
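A sketch of that alternative adversarial loss (generic; names are placeholders):

```python
import torch
import torch.nn.functional as F

def embedding_adv_loss(pred_emb: torch.Tensor, true_emb: torch.Tensor) -> torch.Tensor:
    # the adversary regresses the drug embedding directly; cosine distance
    # to the true embedding serves as the adversarial loss (BERT-style)
    return (1.0 - F.cosine_similarity(pred_emb, true_emb, dim=-1)).mean()
```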
Add a class that takes in a list of canonical SMILES and returns a fixed-size embedding of the molecules (a torch.Tensor).
from typing import List

import torch

# each model will inherit from this interface
class ChemicalRepresentation:
    @classmethod
    def dim(cls) -> int:
        """The number of latent dimensions for this model."""
        raise NotImplementedError

    def __init__(self, dataset: str):
        # loads the model into memory
        raise NotImplementedError

    def encode(self, molecules: List[str]) -> torch.Tensor:
        # encodes the given list of SMILES strings into a tensor emb
        # with emb.shape[0] == len(molecules)
        raise NotImplementedError

    def decode(self, emb: torch.Tensor) -> List[str]:
        # maps embeddings back to a list of SMILES
        raise NotImplementedError

# maps the string to the class
EMBEDDING_MODELS = {
    "GROVER": GroverRepresentation,
}
from pathlib import Path

import pandas

# for the given list of SMILES strings, returns a dataframe with two columns:
#   column 1: SMILES: the smiles string
#   column 2: embedding: numpy array
# Casting back to a torch tensor has to be done at the dataloading level,
# e.g. in the __getitem__ method.
def get_chemical_representation_df(molecules: List[str], embedding_model: str,
                                   dataset: str, cache_dir="datasets/embedding"):
    cache_file = (
        Path(cache_dir) / f"{embedding_model}_{dataset}_df.parquet"
        if cache_dir is not None else None
    )
    if cache_file is not None and cache_file.exists():
        # load the cached dataframe and return it
        return pandas.read_parquet(cache_file)
    model = EMBEDDING_MODELS[embedding_model](dataset)
    embedding = model.encode(molecules)
    df = pandas.DataFrame({"SMILES": molecules,
                           "embedding": list(embedding.detach().cpu().numpy())})
    if cache_file is not None:
        df.to_parquet(cache_file)
    return df
Looking at the results in chemical_CPA/simon/plot_sweep_results.ipynb
We observe that GROVER performs weaker than expected.
Hypothesis: This might be related to the size of the GROVER embedding, which has 3400 dimensions.
Potential experiment: Increase the size of the drug embedders to cope with the large input dimension.
Basic idea: We want to test how much pretraining on a smaller set of genes helps for increasing performance after finetuning on a larger amount of genes.
This is relevant biologically, since commonly a different set of genes is selected for single cell experiments.
Experiment steps:
Implementation steps:
Running the DrugEmb class with the molecular featurizer set to Pretrain fails at computing the graphs from the SMILES strings. It fails at graph_from_smiles:
https://github.com/theislab/chemical_CPA/blob/0ef1363f762cef91f9a046b91827d190aa93225a/compert/data.py#L137-L146
This is the Traceback:
Traceback (most recent call last):
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/sacred/experiment.py", line 312, in run_commandline
return self.run(
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/sacred/experiment.py", line 276, in run
run()
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/sacred/run.py", line 238, in __call__
self.result = self.main_function(*args)
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/sacred/config/captured_function.py", line 42, in captured_function
result = wrapped(*args, **kwargs)
File "/tmp/26918/compert/seml_sweep_icb.py", line 266, in train
experiment = ExperimentWrapper()
File "/tmp/26918/compert/seml_sweep_icb.py", line 69, in __init__
self.init_all()
File "/tmp/26918/compert/seml_sweep_icb.py", line 152, in init_all
self.init_dataset()
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/sacred/config/captured_function.py", line 42, in captured_function
result = wrapped(*args, **kwargs)
File "/tmp/26918/compert/seml_sweep_icb.py", line 84, in init_dataset
self.datasets, self.dataset = load_dataset_splits(
File "/tmp/26918/compert/data.py", line 281, in load_dataset_splits
dataset = Dataset(
File "/tmp/26918/compert/data.py", line 137, in __init__
graph_tuple = graph_from_smiles(
File "/tmp/26918/compert/helper.py", line 162, in graph_from_smiles
graph = smiles2graph(smiles)
File "/tmp/26918/compert/helper.py", line 148, in <lambda>
smiles2graph = lambda smiles: smiles_to_bigraph(
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/dgllife/utils/mol_to_graph.py", line 369, in smiles_to_bigraph
return mol_to_bigraph(mol, add_self_loop, node_featurizer, edge_featurizer,
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/dgllife/utils/mol_to_graph.py", line 269, in mol_to_bigraph
return mol_to_graph(mol, partial(construct_bigraph_from_mol, add_self_loop=add_self_loop),
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/dgllife/utils/mol_to_graph.py", line 85, in mol_to_graph
g.ndata.update(node_featurizer(mol))
File "/home/icb/leon.hetzel/miniconda3/envs/py38/lib/python3.8/site-packages/dgllife/utils/featurizers.py", line 1289, in __call__
self._atomic_number_types.index(atom.GetAtomicNum()),
ValueError: 0 is not in list
Generate 3 new Trapnell splits, one for each of the 3 pathways (epigenetic, cell cycle, tyrosine kinase).
In each split, 10 drugs are left out.
For LINCS we need just one new split, where all 30 drugs are left out.
Part of #67
I started writing the code that does the transfer learning (loads the model pretrained on LINCS for finetuning on Trapnell) and ran into a stumbling block: LINCS has 978 genes, one of which isn't part of Trapnell. Furthermore, the ordering of the genes between lincs_full_smiles.h5ad and trapnell_cpa.h5ad is completely different.
Hence, we should generate two new datasets:
- trapnell_cpa_lincs_genes.h5ad, which contains just the genes that are also part of LINCS, with the exact same ordering
- LINCS_full_smiles_trapnell_genes.h5ad, which contains just the genes that are also part of Trapnell, again with the same ordering
I think generating & storing the datasets is less error prone than trying to fix this in the code.
Can you do this @MxMstrmn? We can also talk about it tmrw.
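A sketch of how the two files could be generated (assuming scanpy and the file names above):

```python
import scanpy as sc

lincs = sc.read_h5ad("lincs_full_smiles.h5ad")
trapnell = sc.read_h5ad("trapnell_cpa.h5ad")

# shared genes, in a single fixed order (here: LINCS order)
shared = [g for g in lincs.var_names if g in set(trapnell.var_names)]

trapnell[:, shared].write("trapnell_cpa_lincs_genes.h5ad")
lincs[:, shared].write("LINCS_full_smiles_trapnell_genes.h5ad")
```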
As it stands there's no parameter finetuning in #85 or #84.
1. autoencoder_lr, autoencoder_wd, batch_size. We should do this separately for the finetuned and from-scratch training.
2. dosers_lr & dosers_wd, autoencoder_lr (autoencoder + drug embedder is updated using the same optimizer).
I think (1) is important, as a good lr may make a difference for finetuning and as the classification task for the adversaries is pretty different on Trapnell. (2) is probably much less important; we could use it as a source of variation during the individual runs.
Measures
Potentially, investigate whether the evaluation is computed on the CPU and move it to the GPU @siboehm
The RDKit embedding (fingerprint) as saved on the server has some Inf and NaN values, which triggers immediate early stopping.
Options:
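One option, sketched here (assuming the fingerprint matrix is a numpy array emb): locate the bad entries and zero them out before training:

```python
import numpy as np

bad = ~np.isfinite(emb)
print(f"{bad.sum()} non-finite entries across {np.unique(np.where(bad)[0]).size} drugs")
emb = np.nan_to_num(emb, nan=0.0, posinf=0.0, neginf=0.0)
```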
On LINCS we have some problems with evaluations resulting in strange [0.0, 0.0, 0.0] results. There is this line:
https://github.com/theislab/chemical_CPA/blob/main/compert/train.py#L68
So probably the 0.0s come from some NaNs in the evaluation. We should remove this line. Now that the evaluation is much faster (17 min without the disentanglement evaluation on LINCS), we can just run it more often and save intermediate checkpoints.
So far just a collection of ideas for a section in our paper, where we try to show that the low-dimensional embeddings (after the drug embedder transformation) are useful for later downstream tasks.
Including:
That is the data directory: home/icb/leon.hetzel/git/CPA_graphs/datasets/
This is still todo, correct? @MxMstrmn
We need to add the code to support architecture surgery. This covers scenarios when we have a pretrained model on X genes but only a subset of X in the new dataset on which we want to transfer. More specifically, this refers to a scenario where we only use say 500 from the 974 lincs genes during the transfer task.
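A sketch of the first-layer part of such a surgery (the helper and the assumption that all new genes exist in the pretrained set are hypothetical):

```python
import torch

def surgery_first_layer(pretrained_w: torch.Tensor,  # (hidden, n_old_genes)
                        old_genes, new_genes) -> torch.Tensor:
    # keep only the input columns for genes present in the new dataset;
    # assumes new_genes is a subset of old_genes
    idx = torch.tensor([old_genes.index(g) for g in new_genes])
    return pretrained_w[:, idx]  # (hidden, n_new_genes)
```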
- docs/ (to replace the chemical_CPA.png)
- the chemCPA/plotting.py file (if it's still needed we should add it back)
Feel free to edit / comment @MxMstrmn
Which tasks to choose for multi-task learning: