ennengyang / adamerging

AdaMerging: Adaptive Model Merging for Multi-Task Learning. ICLR, 2024.

Home Page: https://openreview.net/pdf?id=nZP6NgD3QY

Python 100.00%
model-editing model-fusion model-merging multi-task-learning test-time-adaptation

adamerging's Introduction

AdaMerging

A repository of 'AdaMerging: Adaptive Model Merging for Multi-Task Learning. ICLR, 2024.'.

Abstract

Multi-task learning (MTL) aims to empower a model to tackle multiple tasks simultaneously. A recent development known as task arithmetic has revealed that several models, each fine-tuned for distinct tasks, can be directly merged into a single model to execute MTL without necessitating a retraining process using the initial training data. Nevertheless, this direct addition of models often leads to a significant deterioration in the overall performance of the merged model. This decline occurs due to potential conflicts and intricate correlations among the multiple tasks. Consequently, the challenge emerges of how to merge pre-trained models more effectively without using their original training data. This paper introduces an innovative technique called Adaptive Model Merging (AdaMerging). This approach aims to autonomously learn the coefficients for model merging, either in a task-wise or layer-wise manner, without relying on the original training data. Specifically, our AdaMerging method operates as an automatic, unsupervised task arithmetic scheme. It leverages entropy minimization on unlabeled test samples from the multi-task setup as a surrogate objective function to iteratively refine the merging coefficients of the multiple models. Our experimental findings across eight tasks demonstrate the efficacy of the AdaMerging scheme we put forth. Compared to the current state-of-the-art (SOTA) task arithmetic merging scheme, AdaMerging showcases a remarkable 11% improvement in performance. Notably, AdaMerging also exhibits superior generalization capabilities when applied to unseen downstream tasks. Furthermore, it displays a significantly enhanced robustness to data distribution shifts that may occur during the testing phase.
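
The surrogate objective mentioned in the abstract can be illustrated with a small sketch. This is not the repository's training code: it only shows, in plain Python, the Shannon entropy of a softmax prediction, which is the kind of quantity that entropy minimization on unlabeled test samples drives down. The function names are hypothetical and chosen for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw class scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prediction_entropy(logits):
    """Shannon entropy H(p) = -sum_c p_c * log(p_c) of the softmax distribution."""
    return -sum(p * math.log(p) for p in softmax(logits) if p > 0.0)


def mean_entropy(batch_logits):
    """Average entropy over a batch of unlabeled samples (no labels are needed)."""
    return sum(prediction_entropy(l) for l in batch_logits) / len(batch_logits)


# A confident prediction has lower entropy than an uncertain one.
uncertain = prediction_entropy([0.0, 0.0])   # uniform over two classes
confident = prediction_entropy([10.0, 0.0])  # almost all mass on one class
print(uncertain > confident)
```

Because the objective depends only on the model's own predictions, it can be evaluated on test inputs without labels, which is what allows the merging coefficients to be refined without the original training data.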

Citation

If you find our paper or this resource helpful, please consider citing:

@article{AdaMerging_ICLR_2024,
  title={AdaMerging: Adaptive Model Merging for Multi-Task Learning},
  author={Yang, Enneng and Wang, Zhenyi and Shen, Li and Liu, Shiwei and Guo, Guibing and Wang, Xingwei and Tao, Dacheng},
  journal={The Twelfth International Conference on Learning Representations},
  year={2024}
}

Thanks!

Datasets

Refer to the dataset processing instructions in the task_vectors repository.

Alternatively, you can download the processed data from Baidu Cloud Disk.

Checkpoints

You can download the fine-tuned checkpoints from the task_vectors#checkpoints. The Google Drive folder is: task_vectors_checkpoints

Note: If torch.load(xxx_checkpoint).state_dict() fails, try pickle.load(open(xxx_checkpoint, 'rb')).state_dict() instead.
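
The fallback in the note can be wrapped in a small helper. This is a sketch, not code from the repository: the names load_checkpoint_object and load_state_dict are hypothetical, and the helper assumes a checkpoint holds either a pickled model object (with a state_dict() method) or an already extracted state dict. Only load checkpoints from sources you trust, since unpickling can execute arbitrary code.

```python
import pickle


def load_checkpoint_object(checkpoint_path):
    """Try torch.load first; fall back to plain pickle if that fails."""
    try:
        # Imported lazily so the pickle fallback still works if torch is missing.
        import torch

        # weights_only=False is needed on recent PyTorch releases to unpickle
        # a full model object rather than a bare tensor dictionary.
        return torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    except Exception:
        # Broad on purpose: any torch.load failure triggers the fallback.
        with open(checkpoint_path, "rb") as f:
            return pickle.load(f)


def load_state_dict(checkpoint_path):
    """Return the state dict stored in a checkpoint file."""
    obj = load_checkpoint_object(checkpoint_path)
    # Assumption: the file holds either a model object or a plain state dict.
    return obj.state_dict() if hasattr(obj, "state_dict") else obj
```

A call would look like state_dict = load_state_dict(xxx_checkpoint), where xxx_checkpoint is the path placeholder used in the note above.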

Code

Train

If you want to train AdaMerging, run this part of the code. If you want to load the trained merging coefficients directly, refer to the Eval section.

First enter the root directory of the source code.

cd root_path/src/

Run Task Arithmetic (paper)

python main_task_arithmetic.py

Run TIES-Merging (paper)

python main_ties_merging.py

Run Task-wise AdaMerging (Ours)

python main_task_wise_adamerging.py

Run Task-wise AdaMerging++ (Ours)

python main_task_wise_adamergingpp.py

Run Layer-wise AdaMerging (Ours)

python main_layer_wise_adamerging.py

Run Layer-wise AdaMerging++ (Ours)

python main_layer_wise_adamergingpp.py

Eval

Alternatively, you can load our trained merge coefficients, which can be found in the merging_coefficient.py file. The general process is as follows:

import torch
from merging_cofficient import get_merging_cofficients

# load: method_name and model_name select one set of trained coefficients
ralpha = get_merging_cofficients(method_name, model_name)
self.alpha = torch.Tensor(ralpha)

# wrap: weighted sum of each layer's parameters across all entries of self.paramslist
if self.alpha.size()[0] == 1:  # task-wise merging: one coefficient row shared by every layer
    params = tuple(sum(pi * alphai for pi, alphai in zip(p, self.alpha[0].cpu())) for p in zip(*self.paramslist))
else:  # layer-wise merging: row j holds the coefficients for layer j
    params = tuple(sum(pi * alphai for pi, alphai in zip(p, self.alpha[j].cpu())) for j, p in enumerate(zip(*self.paramslist)))

More details can be found in the following code: https://github.com/EnnengYang/RepresentationSurgery
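
To make the wrap step concrete, here is a minimal sketch of the same weighted sum as a standalone function. The name merge_params and the toy numbers are hypothetical. The layout is an assumption read off the snippet above: paramslist[k][j] is layer j of model k, and alpha has one row for task-wise merging or one row per layer for layer-wise merging. Plain floats stand in for tensors so the example runs without PyTorch; the same arithmetic applies to tensors.

```python
def merge_params(paramslist, alpha):
    """Merge per-layer parameters from several models with learned coefficients.

    paramslist: sequence of models, each a sequence of per-layer parameters.
    alpha: list of coefficient rows; 1 row = task-wise, 1 row per layer = layer-wise.
    """
    task_wise = len(alpha) == 1
    merged = []
    # zip(*paramslist) groups layer j of every model together.
    for j, layer_group in enumerate(zip(*paramslist)):
        coeffs = alpha[0] if task_wise else alpha[j]
        merged.append(sum(p * a for p, a in zip(layer_group, coeffs)))
    return tuple(merged)


# Toy setup: three models with two "layers" each (scalars instead of tensors).
paramslist = [(1.0, 2.0), (10.0, 20.0), (100.0, 200.0)]

# Task-wise: the same coefficients are applied to every layer.
print(merge_params(paramslist, [[1.0, 0.5, 0.25]]))                   # (31.0, 62.0)

# Layer-wise: each layer has its own coefficient row.
print(merge_params(paramslist, [[1.0, 0.5, 0.25], [1.0, 0.0, 1.0]]))  # (31.0, 202.0)
```

The only difference between the two modes is which coefficient row is applied to layer j, which is why the original snippet switches on the first dimension of self.alpha.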

Acknowledgement

Our implementation references the code below, thanks to them.

adamerging's Issues

About dataset preparation

Hi,

Congratulations on the acceptance of your work!

I am trying to reproduce the results but ran into some trouble preparing the datasets, especially SUN397, RESISC45, EuroSAT and DTD. Although this is partly covered in this issue, it is not fully resolved. Could you share your scripts for preprocessing and preparing the datasets? Thanks!

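About merging coefficient

First of all, thank you for sharing this exciting project and research. I am particularly interested in the implementation of the AdaMerging method, especially the part about how the weights of the different tasks (ralpha) are determined. After reading the code and the paper, I have some questions about this process and hope you can help.
For example, in the code the task vector weights are specified directly as: ralpha = [[1.0000, 0.2202, 0.1413, 0.2826, 0.3284, 0.2841, 0.4003, 0.1978, 0.1692]]
Were these weights determined by some optimization process (training), or are they the best values found through experimental tuning?
If the weights were obtained by an automated method, could you share the specific process and steps? This is very important for understanding the mechanism behind the model's overall performance improvement.
Are there recommended methods or practices for determining similar weights when merging models on other datasets or tasks?
Thank you again for your work and for sharing it.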
