
Comments (5)

C-de-Furina commented on July 22, 2024

Well, I understand what is happening now. The main problem is that when you load weights from JAX, the ExponentialMovingAverage (EMA) is initialized incorrectly. In train_openfold.py, around line 307, there is:

if args.resume_from_ckpt:
    if args.resume_model_weights_only:
        # Load the checkpoint
        if os.path.isdir(args.resume_from_ckpt):
            sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
                args.resume_from_ckpt)
        else:
            sd = torch.load(args.resume_from_ckpt)
        # Process the state dict
        if 'module' in sd:
            sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
            import_openfold_weights_(model=model_module, state_dict=sd)
        elif 'state_dict' in sd:
            import_openfold_weights_(
                model=model_module, state_dict=sd['state_dict'])
        else:
            # Loading from pre-trained model
            sd = {'model.'+k: v for k, v in sd.items()}
            import_openfold_weights_(model=model_module, state_dict=sd)
        logging.info("Successfully loaded model weights...")

Note that the 'state_dict' entry holds the parameters used by the EMA. But if you check load_from_jax at line 262:

def load_from_jax(self, jax_path):
    model_basename = os.path.splitext(
            os.path.basename(
                os.path.normpath(jax_path)
            )
    )[0]
    model_version = "_".join(model_basename.split("_")[1:])
    logging.warning('load from version'+model_version)
    import_jax_weights_(
            self.model, jax_path, version=model_version
    )

Nothing here touches a state_dict, so the EMA keeps the parameters it copied when the AF model was first initialized, and those parameters are untrained.
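To make the failure mode concrete, here is a minimal toy sketch in which plain floats stand in for tensors. The class ToyEMA, the parameter name "w", and the decay of 0.999 are all assumptions chosen for illustration; none of them come from the OpenFold code. It only shows why a shadow copy taken before the weights are loaded stays close to the untrained values.

```python
# Toy illustration only: floats stand in for tensors, and ToyEMA is a
# hypothetical stand-in for an exponential moving average over weights.
DECAY = 0.999  # assumed value; the thread does not state the decay used


class ToyEMA:
    def __init__(self, params, decay):
        self.params = dict(params)  # snapshot taken at construction time
        self.decay = decay

    def update(self, params):
        # shadow <- decay * shadow + (1 - decay) * current
        for k, v in params.items():
            self.params[k] = self.decay * self.params[k] + (1.0 - self.decay) * v


model = {"w": 0.01}         # untrained initialization
ema = ToyEMA(model, DECAY)  # the shadow copy captures the untrained value
model["w"] = 1.0            # pretrained weights are loaded afterwards

ema.update(model)           # one "training step" (weights unchanged here)
print(model["w"], ema.params["w"])  # the shadow is still close to 0.01
```

After one update the shadow value has moved only about 0.1% of the way toward the loaded weight, so anything that reads the EMA copy would effectively see untrained parameters.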

There are probably better ways to do this, but I'm not a professional programmer, so I fixed it as follows:

1. Add an attribute to the EMA class

def __init__(self, model: nn.Module, decay: float):
    """
    Args:
        model:
            A torch.nn.Module whose parameters are to be tracked
        decay:
            A value (usually close to 1.) by which updates are
            weighted as part of the above formula
    """
    super(ExponentialMovingAverage, self).__init__()
    clone_param = lambda t: t.clone().detach()
    self.params = tensor_tree_map(clone_param, model.state_dict())
    self.decay = decay
    self.device = next(model.parameters()).device
    # New flag: set to True after JAX weights are loaded into the model
    self.load_from_jax = False

2. Add a method to the EMA class

def repair_params_when_load_from_jax(self, state_dict: OrderedDict) -> None:
    # Overwrite the stale EMA copy with the model's current parameters
    for k in state_dict.keys():
        self.params[k] = state_dict[k].clone()

3. Run this method once when load_from_jax is True, then set the flag back to False

def update(self, model: torch.nn.Module) -> None:
    """
    Updates the stored parameters using the state dict of the provided
    module. The module should have the same structure as that used to
    initialize the ExponentialMovingAverage object.
    """
    if self.load_from_jax:
        self.load_from_jax = False
        self.repair_params_when_load_from_jax(model.state_dict())
    self._update_state_dict_(model.state_dict(), self.params)

4. When you load from JAX, set load_from_jax to True

def load_from_jax(self, jax_path):
    model_basename = os.path.splitext(
            os.path.basename(
                os.path.normpath(jax_path)
            )
    )[0]
    model_version = "_".join(model_basename.split("_")[1:])
    logging.warning('load from version'+model_version)
    import_jax_weights_(
            self.model, jax_path, version=model_version
    )
    self.ema.load_from_jax = True
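The four steps above can be sketched end to end in the same toy style, with floats in place of tensors. ToyEMA, repair_params, the parameter name "w", and the decay value are hypothetical names and values used only for this illustration; the real classes live in OpenFold and operate on state dicts of tensors.

```python
# Toy illustration of the flag-based fix; not the OpenFold implementation.
class ToyEMA:
    def __init__(self, params, decay):
        self.params = dict(params)
        self.decay = decay
        self.load_from_jax = False  # step 1: flag, off by default

    def repair_params(self, params):
        # step 2: overwrite the shadow copy with the model's current values
        for k, v in params.items():
            self.params[k] = v

    def update(self, params):
        # step 3: repair once, right before the first averaging step
        if self.load_from_jax:
            self.load_from_jax = False
            self.repair_params(params)
        for k, v in params.items():
            self.params[k] -= (1.0 - self.decay) * (self.params[k] - v)


model = {"w": 0.01}               # untrained initialization
ema = ToyEMA(model, decay=0.999)  # assumed decay value
model["w"] = 1.0                  # stand-in for loading the JAX weights
ema.load_from_jax = True          # step 4: mark the shadow copy as stale

ema.update(model)
print(ema.params["w"])  # 1.0
```

In this sketch the shadow copy matches the loaded weight after the first update, and the flag resets itself so later updates average normally. Calling the repair method directly right after the weights are loaded would presumably work as well; the flag simply defers the copy to the first update.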

from openfold.

C-de-Furina commented on July 22, 2024

Hi, have you solved this problem? I printed the weights and biases from the original AF parameters and from a ckpt saved after only one training step.
Almost all weights changed sharply, and the biases look outright broken. I don't know why these biases are on the order of 1e-6 or 1e-7, but they are definitely wrong.


jasonkim8652 commented on July 22, 2024

@C-de-Furina Hi, I'm running into the same problem. I'm curious how you made a ckpt-format file from the AlphaFold2 JAX parameters; there is no script in this repository that converts JAX parameters into an OpenFold ckpt. I'd also like to know how you ran the fine-tuning. Was it the same as in the last issue you opened? Thank you for your help!


C-de-Furina commented on July 22, 2024

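@C-de-Furina Hello, I have run into the same problem too. I'm curious how you made a ckpt-format file from the alphafold2 jax parameters. There is no script in this repository for converting jax parameters into openfold ckpt parameters. I would also like to know how you ran the fine-tuning. Was it the same as in the last issue you opened? Thank you for your help!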

I'm still trying to find a way to create a ckpt. Right now I'm trying to call trainer.save_checkpoint("example.ckpt") without any training, but I've run into some trouble. Before this, I created a ckpt by training from the JAX parameters for only one step, so I can say these parameters are broken from the very beginning. My training setup is the same as for the monomer, except for --config_preset="model_5_multimer_v3".

Besides, you can try using your trained ckpt to predict a homomeric protein, such as 8w7d. You will see that all chains overlap almost completely, whereas this does not happen for heteromers. So I think there may be some intrinsic defect in OF-multimer training.


jasonkim8652 commented on July 22, 2024

Thank you for the kind explanation! I'll try your solution and check whether it solves the problem.

