
ucad's Introduction

UCAD in AAAI-2024

Official code for Unsupervised Continual Anomaly Detection with Contrastively-learned Prompt in AAAI-2024 [paper link]

Jiaqi Liu*, Kai Wu*, Qiang Nie, Ying Chen, Bin-Bin Gao, Yong Liu, Jinbao Wang, Chengjie Wang, Feng Zheng†

Introduction

UCAD is a novel unsupervised continual anomaly detection (AD) framework, which we augment with SAM (the Segment Anything Model).

environment

basic

python>=3.8, torch>=1.12, CUDA>=11.3, timm==0.6.7

install SAM:

pip install git+https://github.com/facebookresearch/segment-anything.git

or clone the repository locally and install with

git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .
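To quickly confirm the install works, you can load a ViT-B checkpoint in Python. This is only a minimal sketch; the checkpoint file (e.g. sam_vit_b_01ec64.pth) must be downloaded separately from the SAM repository and is not included here:

import torch
from segment_anything import sam_model_registry

# load the ViT-B SAM backbone from a locally downloaded checkpoint
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.to("cuda" if torch.cuda.is_available() else "cpu")
print("SAM vit_b loaded with", sum(p.numel() for p in sam.parameters()), "parameters")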

prepare for training

rename the dataset directory to 'mvtec2d' and create the SAM semantic directory (a processed mvtec2d-sam-b.zip is provided in the repository)

mvtec_origin_data_path='./mvtec2d'
mvtec_data_path='./mvtec2d-sam-b'
cp -r $mvtec_origin_data_path $mvtec_data_path
cd UCAD/segment_anything
python3 dataset_sam.py --sam_type 'vit_b' --sam_checkpoint $your_sam_path --data_path $mvtec_data_path
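For intuition, the SAM semantic directory mirrors the image directory, with each image replaced by a per-pixel region-id map produced by SAM's automatic mask generator. The snippet below is only a hedged sketch of that idea (paths are illustrative; the actual dataset_sam.py in this repository may differ in its details):

import cv2
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to("cuda")
mask_generator = SamAutomaticMaskGenerator(sam)

# read one MVTec image and segment it with SAM (illustrative path)
image = cv2.cvtColor(cv2.imread("mvtec2d/bottle/train/good/000.png"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)  # list of dicts with boolean 'segmentation' masks

# paint each SAM region with a distinct id, largest regions first so small ones stay visible
region_map = np.zeros(image.shape[:2], dtype=np.uint8)
for idx, m in enumerate(sorted(masks, key=lambda m: m["area"], reverse=True), start=1):
    region_map[m["segmentation"]] = idx

# save under the mirrored mvtec2d-sam-b directory, keeping the same relative path
cv2.imwrite("mvtec2d-sam-b/bottle/train/good/000.png", region_map)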

training and evaluation

prepare the environment:

datapath=/hhd3/m3lab/data/mvtec2d
datasets=('bottle' 'cable' 'capsule' 'carpet' 'grid' 'hazelnut' 'leather' 'metal_nut' 'pill' 'screw' 'tile' 'toothbrush' 'transistor' 'wood' 'zipper')
dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '$dataset; done))

training:

CUDA_VISIBLE_DEVICES=0 python3 run_ucad.py --gpu 0 --seed 0 --memory_size 196 --log_group IM224_UCAD_L5_P01_D1024_M196 --save_segmentation_images --log_project MVTecAD_Results results ucad -b wideresnet50 -le layer2 -le layer3 --faiss_on_gpu --pretrain_embed_dimension 1024 --target_embed_dimension 1024 --anomaly_scorer_num_nn 1 --patchsize 1 sampler -p 0.1 approx_greedy_coreset dataset --resize 224 --imagesize 224 "${dataset_flags[@]}" mvtec $datapath
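For intuition about the "-p 0.1 approx_greedy_coreset" sampler flags: the memory bank keeps only a coreset of patch features chosen greedily so that the retained features cover the full set, roughly the standard k-center greedy idea used by PatchCore. The sketch below only illustrates that idea and is not this repository's exact implementation:

import torch

def greedy_coreset(features: torch.Tensor, percentage: float = 0.1) -> torch.Tensor:
    # select round(N * percentage) features that approximately cover the whole set
    n_select = max(1, int(features.shape[0] * percentage))
    selected = [int(torch.randint(features.shape[0], (1,)))]  # random starting point
    min_dist = torch.cdist(features, features[selected]).squeeze(1)  # distance to nearest selected feature
    for _ in range(n_select - 1):
        idx = int(torch.argmax(min_dist))  # farthest remaining feature
        selected.append(idx)
        new_dist = torch.cdist(features, features[idx:idx + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, new_dist)
    return features[selected]

# e.g. keep 10% of 196 patch features of dimension 1024
coreset = greedy_coreset(torch.randn(196, 1024), percentage=0.1)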

Parameters

Main contents are contained in three files: ./patchcore/patchcore.py, ./patchcore/vision_transformer.py, and ./run_ucad.py. Whether to save the segmentation images, the image size, and the memory size can all be modified in the training command above. Parameters related to the prompt are in ./patchcore/patchcore.py, line 99.

Inference involves a query process and is slow, so it is commented out in the code (./run_ucad.py, lines 408-509). Training directly reports the final results; the inference process merely repeats this step. The final output consists of two parts: the lower metrics are the final results, and the difference between them and the higher metrics is reported as FM.
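In other words, FM here is simply the gap between a task's metric measured right after it was learned and the same metric measured again after all tasks have been learned. A hedged sketch with illustrative numbers (not results from the paper):

# per-task AUROC right after learning each task (the higher metrics)
auroc_after_each_task = [0.97, 0.95, 0.96]
# per-task AUROC after the final task (the lower metrics, i.e. the final results)
auroc_after_all_tasks = [0.93, 0.92, 0.95]

fm_per_task = [a - b for a, b in zip(auroc_after_each_task, auroc_after_all_tasks)]
fm = sum(fm_per_task) / len(fm_per_task)
print(f"final AUROC = {sum(auroc_after_all_tasks) / len(auroc_after_all_tasks):.4f}, FM = {fm:.4f}")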

Acknowledgments

This work is supported by the National Key R&D Program of China (Grant NO. 2022YFF1202903) and the National Natural Science Foundation of China (Grant NO. 62122035).

Our benchmark is built on PatchCore and DualPrompt_Pytorch; thanks for their extraordinary work!

ucad's People

Contributors

shirowalker

ucad's Issues

cv2.error: !ssize.empty() in function 'resize'

Thank you for your great work; I have a question.
When I run the vision_transformer.py code, I get the following error. Can you tell me how to solve it?

            for i in range(x.shape[0]):
                if('mvtec2d' in image_path[i]):
                    sam_score = cv2.imread(image_path[i].replace('mvtec2d','mvtec2d-sam-b'))
                elif('visa' in image_path[i]):
                    sam_score = cv2.imread(image_path[i].replace('visa','visa-sam-b'))
                labels[i] = torch.from_numpy(cv2.resize(sam_score,(14,14))[:,:,0].flatten()).cuda()
            res['loss'] = torch.tensor(0).float().cuda()
            # loss for sam
            for k in range(len(res['seg_feat'])):
                res['loss'] += self.contrastive_loss(res['seg_feat'][k], labels, temperature=0.5)
        else:
            pass

        return res

image

File "/home/ewkim/Desktop/wan/UCAD/UCAD-main/patchcore/vision_transformer.py", line 713, in forward_head
labels[i] = torch.from_numpy(cv2.resize(sam_score,(14,14))[:,0,0].flatten()).cuda()
cv2.error: OpenCV(4.6.0) /io/opencv/modules/imgproc/src/resize.cpp:4052: error: (-215:Assertion failed) !ssize.empty() in function 'resize'
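For context (not an official fix): this OpenCV assertion usually means cv2.imread returned None, i.e. the SAM semantic map derived from image_path[i] does not exist, typically because the mvtec2d-sam-b directory was not generated with segment_anything/dataset_sam.py. A quick check along these lines (the path is illustrative) can confirm whether that is the cause:

import os
import cv2

image_path_i = "./mvtec2d/bottle/train/good/000.png"  # illustrative path
sam_path = image_path_i.replace("mvtec2d", "mvtec2d-sam-b")

if not os.path.exists(sam_path):
    print("Missing SAM map:", sam_path, "- run segment_anything/dataset_sam.py first")
else:
    sam_score = cv2.imread(sam_path)  # returns None if the file cannot be read
    assert sam_score is not None, "cv2.imread failed for " + sam_path
    print("OK, shape:", sam_score.shape)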

Unexplained error

Dear author, could you please take a look at this error?

image

The result of training directly with SAM's segmented mask map is not good

Hello author, I used your code for training and for drawing the segmentation maps. Instead of training on the original images, the code trains on the mask maps produced by SAM, and the resulting segmentation maps are not good, which confuses me.
image
original image:
image
What should I do to obtain the results reported in the paper?

Can't recognize do echo

image

When we run the code and configure the environment variables, we hit a problem where 'do echo' cannot be recognized. Is there a specific solution for this? Thank you.

The sam_vit_b_01ec64.pth file

parser.add_argument('--sam_checkpoint', default='/hhd3/ljq/checkpoints/sam_vit_b_01ec64.pth') # checkpoint path for sam
Is the sam_vit_b_01ec64.pth file referenced in this line of code not provided with the repository?

Some questions about the model

Hello, I think this work is very meaningful! However, I have some doubts about the model. Judging from the code, this work essentially learns a separate feature space for each class, which is similar to training each class independently in a non-continual-learning setting; yet the metrics in this paper are much lower than those in the original PatchCore paper (intuitively, under such similar settings the metrics should be comparable). Perhaps my understanding of the code and the paper is somewhat off; I would appreciate it if you could clear up my doubts. Thanks!
