
researchmm / tasn


Trilinear Attention Sampling Network for Fine-grained Image Recognition

CMake 0.47% Groovy 0.11% Makefile 0.43% R 1.86% C++ 33.06% Python 33.74% Java 0.88% C 0.92% Shell 1.86% Dockerfile 0.20% PowerShell 0.03% Clojure 2.10% HTML 0.23% CSS 0.14% Jupyter Notebook 8.97% Batchfile 0.07% Julia 2.16% MATLAB 0.18% Perl 8.15% Cuda 4.43%

tasn's People

Contributors

heliang-zheng


tasn's Issues

How to make the inference?

Hi,

I have trained the model following your README.md in tasn-mxnet.
I would like to know how I can test the trained model.
I am new to MXNet.
Please show me the script if possible.
Thank you!

Best Regards,
Terry
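
For reference, a minimal inference sketch with the MXNet Module API, assuming a checkpoint saved as ./model/tasn-0300.params and the 3x512x512 input shape used by train.py (both appear elsewhere on this page); this is an illustration, not an official test script:

import mxnet as mx

# Load the symbol and weights saved by train.py (prefix/epoch are assumptions).
prefix, epoch = './model/tasn', 300
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
# If bind() complains about missing label shapes, slice a prediction output
# from sym.get_internals() first (the training symbol carries label inputs).
mod = mx.mod.Module(symbol=sym, context=mx.gpu(0), label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1, 3, 512, 512))])
mod.set_params(arg_params, aux_params, allow_missing=True)

img = mx.nd.random.uniform(shape=(1, 3, 512, 512))   # replace with a preprocessed image
mod.forward(mx.io.DataBatch([img]), is_train=False)
prob = mod.get_outputs()[0].asnumpy()                 # check sym.list_outputs() for which head this is
print('predicted class:', prob.argmax(axis=1))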

Using Attention-based Sampler (AttSampler) in TASN without the need of rebuilding MXNet

Hi there.
I wrote a project that makes the attention-based sampler of TASN usable without rebuilding MXNet.
The link to this project is https://github.com/wkcn/AttentionSampler

It is available for MXNet and PyTorch.
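
For readers who just want the idea rather than the exact operator, below is a minimal PyTorch sketch of attention-based sampling via CDF inversion and F.grid_sample. It only illustrates the concept; it is not the API of the AttentionSampler project or of TASN's att_sampler operator, and the function name and details are made up for the example.

import torch
import torch.nn.functional as F

def attention_sample(image, att, out_size):
    # image: (N, C, H, W); att: (N, 1, H, W) non-negative attention map.
    n = image.size(0)
    # Marginal attention distributions along x and y (with a small floor).
    ax = att.mean(dim=(1, 2)) + 1e-4                       # (N, W)
    ay = att.mean(dim=(1, 3)) + 1e-4                       # (N, H)
    # Normalized cumulative distributions.
    cx = torch.cumsum(ax, dim=1); cx = cx / cx[:, -1:]
    cy = torch.cumsum(ay, dim=1); cy = cy / cy[:, -1:]
    # Invert the CDFs: uniformly spaced quantiles -> source coordinates,
    # so regions with high attention receive denser sampling.
    q = torch.linspace(0, 1, out_size, device=image.device).expand(n, -1).contiguous()
    ix = torch.searchsorted(cx, q).clamp(max=ax.size(1) - 1)
    iy = torch.searchsorted(cy, q).clamp(max=ay.size(1) - 1)
    # Indices -> normalized grid coordinates in [-1, 1].
    gx = ix.float() / (ax.size(1) - 1) * 2 - 1             # (N, out_size), x
    gy = iy.float() / (ay.size(1) - 1) * 2 - 1             # (N, out_size), y
    grid = torch.stack(torch.broadcast_tensors(gx.unsqueeze(1),   # x varies along width
                                               gy.unsqueeze(2)), -1)
    return F.grid_sample(image, grid, align_corners=True)  # (N, C, out_size, out_size)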

The result (default setting):

INFO:root:Epoch[204] Train-att_net_accuracy=1.000000
INFO:root:Epoch[204] Train-part_net_accuracy=0.979167
INFO:root:Epoch[204] Train-master_net_accuracy=0.989583
INFO:root:Epoch[204] Train-part_net_aux_accuracy=0.979167
INFO:root:Epoch[204] Train-master_net_aux_accuracy=0.989583                         
INFO:root:Epoch[204] Train-distillation_loss=4.280940                               
INFO:root:Epoch[204] Time cost=20.882
INFO:root:Epoch[204] Validation-att_net_accuracy=0.806771
INFO:root:Epoch[204] Validation-part_net_accuracy=0.849132
INFO:root:Epoch[204] Validation-master_net_accuracy=0.856944
INFO:root:Epoch[204] Validation-part_net_aux_accuracy=0.870486
INFO:root:Epoch[204] Validation-master_net_aux_accuracy=0.867361
INFO:root:Epoch[204] Validation-distillation_loss=3.713491



INFO:root:Epoch[299] Train-att_net_accuracy=1.000000
INFO:root:Epoch[299] Train-part_net_accuracy=0.984375
INFO:root:Epoch[299] Train-master_net_accuracy=0.984375
INFO:root:Epoch[299] Train-part_net_aux_accuracy=1.000000
INFO:root:Epoch[299] Train-master_net_aux_accuracy=1.000000
INFO:root:Epoch[299] Train-distillation_loss=4.100089
INFO:root:Epoch[299] Time cost=20.978
INFO:root:Saved checkpoint to "./model/tasn-0300.params"
INFO:root:Epoch[299] Validation-att_net_accuracy=0.804986
INFO:root:Epoch[299] Validation-part_net_accuracy=0.856728
INFO:root:Epoch[299] Validation-master_net_accuracy=0.860485
INFO:root:Epoch[299] Validation-part_net_aux_accuracy=0.864754
INFO:root:Epoch[299] Validation-master_net_aux_accuracy=0.869023
INFO:root:Epoch[299] Validation-distillation_loss=3.620270

Hope that it will be helpful for you!

Questions about detailed attention map during test phase

Hi, thanks for your contribution!
In Section 4.2 of the paper, you mention that

we randomly select a channel of attention maps in each iteration in training stage, and conduct average pooling over attention maps for testing

If we do average pooling during testing, then there is no difference between the structure attention map and the detail attention map. Could you explain the detailed operations in the testing process?
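
To make the quoted sentence concrete, here is a hedged sketch of what the described behaviour could look like in PyTorch (one random channel while training, channel-wise average pooling at test time); it is an illustration, not the paper's released implementation:

import torch

def detail_attention(att_maps, training):
    # att_maps: (N, C, H, W) trilinear attention maps.
    if training:
        idx = torch.randint(att_maps.size(1), (1,)).item()
        return att_maps[:, idx:idx + 1]        # one randomly chosen channel
    return att_maps.mean(dim=1, keepdim=True)  # average pooling over channels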

The Multi_Accuracy metric is not compatible with mxnet 1.6.0

Hi,

I tried to train the network after changing only the batch size and GPUs in the default setting, and I get the following error, which occurs after the first batch finishes.

[09:20:05] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[09:20:05] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
[09:20:05] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[09:20:05] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
INFO:root:start with arguments Namespace(batch_size=2, benchmark=0, data_nthreads=128, disp_batches=20, dtype='float32', gpus='0', image_shape='3,512,512', kv_store='device', load_epoch=None, lr=0.1, lr_factor=0.1, lr_step_epochs='100,200', max_random_aspect_ratio=0, max_random_h=0, max_random_l=0, max_random_rotate_angle=0, max_random_s=0, max_random_scale=1, max_random_shear_ratio=0, min_random_scale=1, model_prefix='./model/tasn', mom=0, monitor=0, network=None, num_classes=200, num_epochs=300, num_examples=5994, num_layers=None, optimizer='sgd', pad_size=0, random_crop=1, random_mirror=1, rgb_mean='123.68,116.779,103.939', test_io=0, top_k=5, wd=0)
[09:20:05] src/io/iter_image_recordio_2.cc:178: ImageRecordIOParser2: ./data/cub/train.rec, use 4 threads for decoding..
[09:20:08] src/io/iter_image_recordio_2.cc:178: ImageRecordIOParser2: ./data/cub/val.rec, use 4 threads for decoding..
learning rate from ``lr_scheduler`` has been overwritten by ``learning_rate`` in optimizer.
INFO:root:Epoch[0] Batch [0-20] Speed: 33.71 samples/sec att_net_accuracy=0.000000 part_net_accuracy=0.023810 master_net_accuracy=0.023810 part_net_aux_accuracy=0.023810 master_net_aux_accuracy=0.023810 distillation_loss=5.296982
Traceback (most recent call last):
  File "train.py", line 57, in <module>
    eval_metric = evaluate.Multi_Accuracy(num=6))
  File "/home/ysu/project/attention_net/tasn/tasn-mxnet/example/tasn/common/fit.py", line 195, in fit
    monitor = monitor
  File "/home/ysu/mxnet_attention/lib/python3.5/site-packages/mxnet/module/base_module.py", line 533, in fit
    self.update_metric(eval_metric, data_batch.label)
  File "/home/ysu/mxnet_attention/lib/python3.5/site-packages/mxnet/module/module.py", line 775, in update_metric
    self._exec_group.update_metric(eval_metric, labels, pre_sliced)
  File "/home/ysu/mxnet_attention/lib/python3.5/site-packages/mxnet/module/executor_group.py", line 640, in update_metric
    eval_metric.update_dict(labels_, preds)
  File "/home/ysu/mxnet_attention/lib/python3.5/site-packages/mxnet/metric.py", line 133, in update_dict
    self.update(label, pred)
  File "/home/ysu/project/attention_net/tasn/tasn-mxnet/example/tasn/common/evaluate.py", line 32, in update
    self.sum_metric[i] += (pred_label.flat == label.flat).sum()
TypeError: 'float' object is not subscriptable

The reason is that, in mxnet 1.6.0, the EvalMetric class has not only num_inst and sum_metric, but also global_num_inst and global_sum_metric.

In the batch_end_callback (here, Speedometer), it calls the reset_local() function to reset num_inst and sum_metric, rather than the reset() function as in older versions of mxnet.

However, the Multi_Accuracy class does not implement reset_local(), so sum_metric is reset to the scalar 0.0 by the reset_local() of the EvalMetric base class, and the next update then fails when indexing it.

A quick solution could be to set the auto_reset argument of Speedometer to False.
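
Alternatively, a hedged sketch of the fix described above is to give Multi_Accuracy a reset_local() that restores the per-output arrays; only the reset logic is shown, and update()/get() are assumed to stay as in common/evaluate.py:

import numpy as np
import mxnet as mx

class Multi_Accuracy(mx.metric.EvalMetric):
    # Accuracy over several outputs; the array-valued counters are assumed
    # from the traceback above (self.sum_metric[i] is indexed per output).
    def __init__(self, num=6):
        self.num = num
        super(Multi_Accuracy, self).__init__('multi-accuracy')

    def reset(self):
        # one slot per output instead of the base class's scalars
        self.num_inst = np.zeros(self.num)
        self.sum_metric = np.zeros(self.num)

    # Newer MXNet's Speedometer(auto_reset=True) calls reset_local() instead of
    # reset(); without this override the base class sets sum_metric back to 0.0.
    def reset_local(self):
        self.reset()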

questions about Structure_Att and Detail_Att

Hi, thanks for your work!
There are two points in the code I cannot understand:
1. Line 133 in tasn-mxnet/example/tasn/model.py: mx.nd.batch_dot(b.reshape((n,c,1)), f.reshape((n,c,w*w)), True, False).reshape((n,1,w,w))
b is sorted but f is not, so I don't think their attention channels can match one to one.

2. The channels in both Structure_Att and Detail_Att are filtered to make sure each channel is unique, but why do you think that channels with the same sum value can be represented by just one of them?

Looking forward to your reply.
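
For reference, the shape behaviour of the quoted batch_dot call can be checked with toy data (random tensors, not the model's actual ones):

import mxnet as mx

n, c, w = 2, 8, 4
b = mx.nd.random.uniform(shape=(n, c))          # one attention weight per channel and sample
f = mx.nd.random.uniform(shape=(n, c, w, w))    # per-channel maps

out = mx.nd.batch_dot(b.reshape((n, c, 1)),     # transposed to (n, 1, c)
                      f.reshape((n, c, w * w)), # (n, c, w*w)
                      True, False).reshape((n, 1, w, w))
print(out.shape)                                # (2, 1, 4, 4): a channel-weighted sum of the maps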

ImportError att_grid_generator_cuda error

Hi authors, nice work. I followed the instructions and set up the PyTorch code correctly, but when I import att_grid_generator_cuda a problem occurs: "ImportError: anaconda3/lib/python3.6/site-packages/att_grid_generator-0.0.0-py3.6-linux-x86_64.egg/att_grid_generator_cuda.cpython-36m-x86_64-linux-gnu.so: undefined symbol: _ZN3c1011CPUTensorIdEv". Can you help me fix this problem? Thanks.
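
An undefined c10 symbol like this usually means the extension was compiled against a different PyTorch release than the one being imported; a hedged first check is to compare versions and rebuild in the same environment:

import torch
print(torch.__version__, torch.version.cuda)   # the build the extension must match
# If this differs from the PyTorch used when the extension was built, rebuild it
# in this environment, e.g.: python setup.py clean --all && python setup.py install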

about AttSampler

Hi, I am using PyTorch to implement your code for AttSampler. I don't understand your code in src/operator/contrib/att_sampler-inl.h,
src/operator/contrib/att_sampler.cc, and
src/operator/contrib/att_sampler.cu.
Can you elaborate on this part? Thank you.

pre-built mxnet is useless

I tried to use your pre-built MXNet and install it, but I got the same error:
no AttSampler method, and I can't find the C files in contrib.

RuntimeError: simple_bind error. Arguments

Hello! I have recently been trying your code, but the following error occurred during the experiment:
`
(python37) E:\QKYO\DemoForThesis\tasn\tasn-mxnet\example\tasn>python train.py --gpus 0,1,2,3,4,5,6 --model-prefix ./model/tasn --data-nthreads 128 --batch-size 96 --num-classes 200 --num-examples 5994
[17:30:20] E:\QKYO\DemoForThesis\tasn\tasn-mxnet\src\nnvm\legacy_json_util.cc:209: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[17:30:20] E:\QKYO\DemoForThesis\tasn\tasn-mxnet\src\nnvm\legacy_json_util.cc:217: Symbol successfully upgraded!
[17:30:20] E:\QKYO\DemoForThesis\tasn\tasn-mxnet\src\nnvm\legacy_json_util.cc:209: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[17:30:20] E:\QKYO\DemoForThesis\tasn\tasn-mxnet\src\nnvm\legacy_json_util.cc:217: Symbol successfully upgraded!
INFO:root:start with arguments Namespace(batch_size=96, benchmark=0, data_nthreads=128, disp_batches=20, dtype='float32', gpus='0,1,2,3,4,5,6', image_shape='3,512,512', kv_store='device', load_epoch=None, lr
=0.1, lr_factor=0.1, lr_step_epochs='100,200', max_random_aspect_ratio=0, max_random_h=0, max_random_l=0, max_random_rotate_angle=0, max_random_s=0, max_random_scale=1, max_random_shear_ratio=0, min_random_s
cale=1, model_prefix='./model/tasn', mom=0, monitor=0, network=None, num_classes=200, num_epochs=300, num_examples=5994, num_layers=None, optimizer='sgd', pad_size=0, random_crop=1, random_mirror=1, rgb_mean
='123.68,116.779,103.939', test_io=0, top_k=5, wd=0)
[17:30:20] E:\QKYO\DemoForThesis\tasn\tasn-mxnet\src\io\iter_image_recordio_2.cc:172: ImageRecordIOParser2: ./data/cub/train.rec, use 3 threads for decoding..
[17:30:20] E:\QKYO\DemoForThesis\tasn\tasn-mxnet\src\io\iter_image_recordio_2.cc:172: ImageRecordIOParser2: ./data/cub/val.rec, use 3 threads for decoding..
Traceback (most recent call last):
File "E:\Anaconda3\envs\python37\lib\site-packages\mxnet-1.3.1-py3.7.egg\mxnet\symbol\symbol.py", line 1523, in simple_bind
ctypes.byref(exe_handle)))
File "E:\Anaconda3\envs\python37\lib\site-packages\mxnet-1.3.1-py3.7.egg\mxnet\base.py", line 252, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: Error in operator part_net: Shape inconsistent, Provided = [14], inferred shape=[13]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "train.py", line 59, in
eval_metric = evaluate.Multi_Accuracy(num=6))
File "E:\QKYO\DemoForThesis\tasn\tasn-mxnet\example\tasn\common\fit.py", line 195, in fit
monitor = monitor)
File "E:\Anaconda3\envs\python37\lib\site-packages\mxnet-1.3.1-py3.7.egg\mxnet\module\base_module.py", line 499, in fit
for_training=True, force_rebind=force_rebind)
File "E:\Anaconda3\envs\python37\lib\site-packages\mxnet-1.3.1-py3.7.egg\mxnet\module\module.py", line 429, in bind
state_names=self._state_names)
File "E:\Anaconda3\envs\python37\lib\site-packages\mxnet-1.3.1-py3.7.egg\mxnet\module\executor_group.py", line 279, in init
self.bind_exec(data_shapes, label_shapes, shared_group)
File "E:\Anaconda3\envs\python37\lib\site-packages\mxnet-1.3.1-py3.7.egg\mxnet\module\executor_group.py", line 375, in bind_exec
shared_group))
File "E:\Anaconda3\envs\python37\lib\site-packages\mxnet-1.3.1-py3.7.egg\mxnet\module\executor_group.py", line 662, in _bind_ith_exec
shared_buffer=shared_data_arrays, **input_shapes)
File "E:\Anaconda3\envs\python37\lib\site-packages\mxnet-1.3.1-py3.7.egg\mxnet\symbol\symbol.py", line 1529, in simple_bind
raise RuntimeError(error_msg)
RuntimeError: simple_bind error. Arguments:
data: (14, 3, 512, 512)
att_net_label: (14,)
part_net_label: (14,)
master_net_label: (14,)
part_net_aux_label: (14,)
master_net_aux_label: (14,)
Error in operator part_net: Shape inconsistent, Provided = [14], inferred shape=[13]
`
I am a newcomer in this area... may I ask whether you know the possible causes of this error?
Also, I don't quite understand the Added Files you mentioned. May I ask what they are used for? Thanks!

output not fed into the softmax

Hello, I have noticed that in main_tasn.py the outputs of fc are used directly to compute the loss without being fed into a softmax. Is there a reason for this, or is it just a mistake?
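
As a general PyTorch note (not a statement about the authors' intent): nn.CrossEntropyLoss combines log-softmax and negative log-likelihood, so it expects raw fc outputs (logits), and an explicit softmax beforehand would be redundant. A minimal illustration:

import torch
import torch.nn as nn

logits = torch.randn(4, 200)                   # fc outputs for 4 samples, 200 classes
labels = torch.randint(0, 200, (4,))
loss = nn.CrossEntropyLoss()(logits, labels)   # softmax is applied internally
print(loss.item())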

Successfully run on Windows 10!

As TASN uses "att_grid_generator_cuda" in its code, we should first run "python setup.py build/install".

However, on Windows 10 we may encounter some problems. Here is my solution; I hope it is helpful. (It indeed took me a lot of time.)

First, if you use VS in Chinese, you need to change the code around line 299 in "cpp_extension.py" like this:

    try:
        if sys.platform.startswith('linux'):
            minimum_required_version = MINIMUM_GCC_VERSION
            versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
            version = versionstr.decode().strip().split('.')
        else:
            # minimum_required_version = MINIMUM_MSVC_VERSION
            # compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
            # match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip())
            # version = (0, 0, 0) if match is None else match.groups()
            print("________________windows operation system ______________________")
            minimum_required_version = MINIMUM_MSVC_VERSION
            compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
            print("________________compiler info:", compiler_info)
            # match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip())
            match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode('gbk').strip())
            version = (0, 0, 0) if match is None else match.groups()

The commented-out lines above are the original code. (This step avoids the encoding problem between utf-8 and gbk.)

Next, we should change the code in "tasn-pytorch\src".

For "att_grid_generator_cuda.cpp", you should change all "AT_CHECK" into "TORCH_CHECK".

And for "att_grid_generator_cuda_kernel.cu", you should remove all the "const" of the parameters in fuction "attgridgen_gpu". Just like this:

void attgridgen_gpu(at::Tensor attx, at::Tensor atty,
    at::Tensor map_xi, at::Tensor map_yi,
    at::Tensor index_x, at::Tensor index_y,
    int batch_size, int att_size, int out_size, 
    float threshold, int iters)
{...}

Then the calls "python setup.py build" and "python setup.py install" run successfully.
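
Once the build finishes, a quick sanity check is to import the extension in the same environment (the module name is the one reported in the ImportError issue above):

import torch                     # load torch first so its shared libraries are available
import att_grid_generator_cuda   # the CUDA extension built by setup.py
print('loaded from', att_grid_generator_cuda.__file__)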

Expected date of PyTorch implementation?

First of all, thank you for your awesome idea! I really like the neat way you combine attention and part recognition. Congratulations on your inspiring work!

I followed the README docs and successfully reproduced the experimental results from your paper. However, the fact that the code relies on a customized version of MXNet, and that the computation is done with symbol instead of gluon, is not quite favorable. I noticed at the very end of your README file that you are planning to release a PyTorch version. Is there an expected date for that?

Got error after running "sudo bash train.sh" command

First,
Hi, I have followed all the necessary steps mentioned in the readme. I installed mxnet, nccl, and cudnn using "sudo bash install.sh". Up to here everything works fine; all the necessary files were downloaded and installed. On executing the final command, i.e. "sudo bash train.sh", I got the following error. Please guide me on this.
Traceback (most recent call last):
  File "train.py", line 13, in
    from common import fit, evaluate
  File "/content/gdrive/My Drive/MobulaOP/tasn/tasn-mxnet/example/tasn/common/fit.py", line 1, in
    import mxnet as mx
  File "/content/gdrive/My Drive/MobulaOP/tasn/tasn-mxnet/example/tasn/mxnet_python/python/mxnet/__init__.py", line 24, in
    from .context import Context, current_context, cpu, gpu, cpu_pinned
  File "/content/gdrive/My Drive/MobulaOP/tasn/tasn-mxnet/example/tasn/mxnet_python/python/mxnet/context.py", line 24, in
    from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
  File "/content/gdrive/My Drive/MobulaOP/tasn/tasn-mxnet/example/tasn/mxnet_python/python/mxnet/base.py", line 213, in
    _LIB = _load_lib()
  File "/content/gdrive/My Drive/MobulaOP/tasn/tasn-mxnet/example/tasn/mxnet_python/python/mxnet/base.py", line 204, in _load_lib
    lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_LOCAL)
  File "/usr/lib/python3.6/ctypes/__init__.py", line 348, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: libcudart.so.8.0: cannot open shared object file: No such file or directory

Second,
I followed a similar process, except that I commented out the mxnet installation steps in the "install.sh" file, with the goal of removing the above error, and installed mxnet using "pip install mxnet-cu101" instead. On executing "sudo bash train.sh", the above error was gone, but I ran into the next error, shown below:

Error-Start
[10:25:25] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
sh: 1: nvcc: not found
sh: 1: nvcc: not found
Error in CustomOp.forward: Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1005, in forward_entry
aux=tensors[4])
File "/content/gdrive/My Drive/MobulaOP/mobula/glue/mxnet_glue.py", line 109, in forward
out = self._forward(*in_data)
File "./AttentionSampler/attention_sampler/attention_sampler.py", line 60, in forward
mobula.func.map_step(N, attxi, index_y, stepx, att_size, out_size)
File "/content/gdrive/My Drive/MobulaOP/mobula/func.py", line 264, in call
using_async=using_async)
File "/content/gdrive/My Drive/MobulaOP/mobula/func.py", line 145, in call
func = self.loader(self, arg_types, ctx, **self.loader_kwargs)
File "/content/gdrive/My Drive/MobulaOP/mobula/op/loader.py", line 499, in init
_build_lib(cpp_fname, code_buffer, ctx, dll_fname)
File "/content/gdrive/My Drive/MobulaOP/mobula/op/loader.py", line 237, in _build_lib
source_to_so_ctx(build_path, srcs, target_name, ctx)
File "/content/gdrive/My Drive/MobulaOP/mobula/building/build.py", line 167, in source_to_so_ctx
buildin_cpp, buildin_o), compiler, cflags)
File "/content/gdrive/My Drive/MobulaOP/mobula/building/build.py", line 41, in source_to_o
run_command_parallel(commands)
File "/content/gdrive/My Drive/MobulaOP/mobula/building/build_utils.py", line 97, in run_command_parallel
raise RuntimeError(info)
RuntimeError: Error, terminated :-(

Traceback (most recent call last):
File "train.py", line 57, in
eval_metric = evaluate.Multi_Accuracy(num=6))
File "/content/gdrive/My Drive/MobulaOP/tasn/tasn-mxnet/example/tasn/common/fit.py", line 195, in fit
monitor = monitor)
File "/usr/local/lib/python3.6/dist-packages/mxnet/module/base_module.py", line 533, in fit
self.update_metric(eval_metric, data_batch.label)
File "/usr/local/lib/python3.6/dist-packages/mxnet/module/module.py", line 775, in update_metric
self._exec_group.update_metric(eval_metric, labels, pre_sliced)
File "/usr/local/lib/python3.6/dist-packages/mxnet/module/executor_group.py", line 648, in update_metric
eval_metric.update_dict(labels_, preds)
File "/usr/local/lib/python3.6/dist-packages/mxnet/metric.py", line 132, in update_dict
self.update(label, pred)
File "/content/gdrive/My Drive/MobulaOP/tasn/tasn-mxnet/example/tasn/common/evaluate.py", line 23, in update
pred_label = mx.nd.argmax_channel(preds[i]).asnumpy()
File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py", line 2566, in asnumpy
ctypes.c_size_t(data.size)))
File "/usr/local/lib/python3.6/dist-packages/mxnet/base.py", line 246, in check_call
raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
File "src/operator/custom/custom.cc", line 346
MXNetError: Check failed: reinterpret_cast( params.info->callbacks[kCustomOpForward])( ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()), reinterpret_cast<const int*>(req.data()), static_cast(ctx.is_train), params.info->contexts[kCustomOpForward]):
Error-End

Please guide me on how I can resolve these errors and run TASN successfully.
