Comments (5)
Well, I understand what happens now. The main problem is that when you load from JAX, the ExponentialMovingAverage (EMA) is initialized incorrectly. You can see this in train_openfold.py, line 307:
if args.resume_from_ckpt:
    if args.resume_model_weights_only:
        # Load the checkpoint
        if os.path.isdir(args.resume_from_ckpt):
            sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
                args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
        # Process the state dict
        if 'module' in sd:
            sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
            import_openfold_weights_(model=model_module, state_dict=sd)
        elif 'state_dict' in sd:
            import_openfold_weights_(
                model=model_module, state_dict=sd['state_dict'])
        else:
            # Loading from pre-trained model
            sd = {'model.'+k: v for k, v in sd.items()}
            import_openfold_weights_(model=model_module, state_dict=sd)
        logging.info("Successfully loaded model weights...")
Look, the 'state_dict' variable holds the parameters that the EMA uses. But if you check line 262:
def load_from_jax(self, jax_path):
    model_basename = os.path.splitext(
        os.path.basename(
            os.path.normpath(jax_path)
        )
    )[0]
    model_version = "_".join(model_basename.split("_")[1:])
    logging.warning('load from version'+model_version)
    import_jax_weights_(
        self.model, jax_path, version=model_version
    )
Nothing here touches state_dict, so the EMA can only use the original parameters that come from the AF model initialization, and these are untrained.
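To make the failure mode concrete, here is a toy sketch in plain Python. It is not OpenFold code: `ToyModel` and `ToyEMA` are hypothetical stand-ins, the parameters are plain floats instead of tensors, and the update rule (move the stored value toward the live value by a factor of `1 - decay`) is my assumption about how the EMA behaves. The point is only that a copy taken at construction time does not see weights loaded afterwards.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a torch.nn.Module; parameters are floats."""

    def __init__(self):
        # Pretend these are the freshly initialized (untrained) values.
        self.weights = {"linear.weight": 0.0, "linear.bias": 0.0}

    def state_dict(self):
        return self.weights


class ToyEMA:
    """Hypothetical stand-in for the EMA: snapshots parameters at creation."""

    def __init__(self, model, decay):
        self.params = copy.deepcopy(model.state_dict())
        self.decay = decay

    def update(self, model):
        # Assumed rule: stored -= (1 - decay) * (stored - live)
        for name, live in model.state_dict().items():
            self.params[name] -= (1.0 - self.decay) * (self.params[name] - live)


model = ToyModel()
ema = ToyEMA(model, decay=0.999)  # snapshot taken BEFORE the weights are loaded

# Simulate loading pretrained weights into the model in place.
model.weights.update({"linear.weight": 1.0, "linear.bias": 0.5})

ema.update(model)  # one training step
print("model:", model.state_dict())
print("ema:  ", ema.params)  # still very close to the untrained values
```

After one step the EMA copy has barely moved away from the untrained values, even though the model itself holds the loaded weights.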
There are probably better ways to do this, but I'm not a professional programmer, so I fixed it as follows:
1. Add an attribute to the EMA class:
def __init__(self, model: nn.Module, decay: float):
    """
    Args:
        model:
            A torch.nn.Module whose parameters are to be tracked
        decay:
            A value (usually close to 1.) by which updates are
            weighted as part of the above formula
    """
    super(ExponentialMovingAverage, self).__init__()
    clone_param = lambda t: t.clone().detach()
    self.params = tensor_tree_map(clone_param, model.state_dict())
    self.decay = decay
    self.device = next(model.parameters()).device
    self.load_from_jax = False
2. Add a method to the EMA class:
def repair_params_when_load_from_jax(self, state_dict: OrderedDict) -> None:
    for k in state_dict.keys():
        self.params[k] = state_dict[k].clone()
3. Run this method once, when load_from_jax is true, and then set the flag back to false:
def update(self, model: torch.nn.Module) -> None:
    """
    Updates the stored parameters using the state dict of the provided
    module. The module should have the same structure as that used to
    initialize the ExponentialMovingAverage object.
    """
    if self.load_from_jax:
        self.load_from_jax = False
        self.repair_params_when_load_from_jax(model.state_dict())
    self._update_state_dict_(model.state_dict(), self.params)
4. When you load from JAX, set load_from_jax to true:
def load_from_jax(self, jax_path):
    model_basename = os.path.splitext(
        os.path.basename(
            os.path.normpath(jax_path)
        )
    )[0]
    model_version = "_".join(model_basename.split("_")[1:])
    logging.warning('load from version'+model_version)
    import_jax_weights_(
        self.model, jax_path, version=model_version
    )
    self.ema.load_from_jax = True
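The four steps can be tried end to end in a toy setting. This is only a sketch under stated assumptions: `ToyEMA` is a hypothetical stand-in that works on a dict of floats rather than tensors, the averaging rule is assumed, and only the names `load_from_jax`, `repair_params_when_load_from_jax` and `update` are taken from the steps above.

```python
import copy


class ToyEMA:
    """Hypothetical float-based stand-in for the patched EMA class."""

    def __init__(self, state_dict, decay):
        self.params = copy.deepcopy(state_dict)
        self.decay = decay
        self.load_from_jax = False  # step 1: new attribute

    def repair_params_when_load_from_jax(self, state_dict):
        # Step 2: overwrite the stale snapshot with the loaded values.
        for name, value in state_dict.items():
            self.params[name] = copy.deepcopy(value)

    def update(self, state_dict):
        # Step 3: resynchronize once, then fall back to normal averaging.
        if self.load_from_jax:
            self.load_from_jax = False
            self.repair_params_when_load_from_jax(state_dict)
        for name, live in state_dict.items():
            # Assumed rule: stored -= (1 - decay) * (stored - live)
            self.params[name] -= (1.0 - self.decay) * (self.params[name] - live)


weights = {"linear.weight": 0.0, "linear.bias": 0.0}  # untrained init
ema = ToyEMA(weights, decay=0.999)

# Step 4: simulate the JAX import, then raise the flag.
weights.update({"linear.weight": 1.0, "linear.bias": 0.5})
ema.load_from_jax = True

ema.update(weights)
print(ema.params)          # now matches the loaded weights
print(ema.load_from_jax)   # flag is cleared after the first update
```

With the flag set, the first call to `update` copies the loaded values before averaging, so the EMA copy starts from the pretrained weights instead of the untrained ones.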
from openfold.
Hi, have you solved this problem? I tried printing the weights and biases from the original AF parameters and from a ckpt saved after only one training step.
Almost all the weights changed sharply, and all the biases look broken. I do not know why these biases are on the order of e-6 or e-7, but they are definitely wrong.
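A comparison like this can be scripted so that every parameter is checked the same way. The sketch below is a hypothetical helper, not something from the repository: `max_abs_change` is a name I made up, plain lists stand in for tensors, and the numbers are illustrative placeholders rather than real OpenFold values.

```python
def max_abs_change(before, after):
    """Return {name: largest absolute element-wise change} for shared keys."""
    report = {}
    for name in before.keys() & after.keys():
        report[name] = max(
            abs(new - old) for old, new in zip(before[name], after[name])
        )
    return report


# Placeholder values only; in practice these would come from the two
# parameter sets being compared (flattened to lists of floats).
before = {"linear.weight": [0.50, -0.25], "linear.bias": [0.10, 0.20]}
after = {"linear.weight": [0.50, -0.25], "linear.bias": [1e-6, 2e-7]}

report = max_abs_change(before, after)
for name in sorted(report):
    print(f"{name}: max |change| = {report[name]:.7f}")
```

Sorting the report by size of change makes it easier to see whether only a few parameters moved or whether nearly everything did.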
from openfold.
@C-de-Furina Hi, I am suffering from the same problem. I'm curious how you made a ckpt-format file from the AlphaFold2 JAX parameters. There is no script in this repository to convert JAX parameters into an OpenFold ckpt. I would also like to know how you ran the fine-tuning. Was it the same as in the last issue you opened? Thank you for your help!
from openfold.
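@C-de-Furina Hello, I have run into the same problem too. I'm curious how you made a ckpt-format file from the alphafold2 jax parameters. There is no script in this repository to convert jax parameters into openfold ckpt parameters. I would also like to know how you ran the fine-tuning. Was it the same as in the last issue you opened? Thank you for your help!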
I'm still trying to find a way to create a ckpt. Right now I am trying to call "trainer.save_checkpoint("example.ckpt")" without any training, but I have run into some trouble. Before this, I created a ckpt by training from JAX for only one step. So I can say these parameters are broken from the beginning. My training setup is the same as for the monomer, except for "--config_preset="model_5_multimer_v3"".
Besides, you can try using your trained ckpt to predict a homomeric protein, such as 8w7d. You will see that all the chains overlap almost completely, whereas this does not happen for heteromers. So I think there may be some intrinsic defect in OF-multimer training.
from openfold.
Thank you for your kind explanation! I will try your solution and check whether the problem is solved.
from openfold.