Comments (9)
@xinyu1205 How much GPU memory is needed to run the RAM model?
3.8G
from recognize-anything.
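For Tag2Text, you can change the code to run inference in batches, which greatly improves inference efficiency. If you only need the tagging output, you can have the generate function output only the tags and stop there.
In addition, our newly released Recognize Anything Model (RAM) can recognize arbitrary user-defined categories, which can further improve tagging inference efficiency. This feature will be released in the future.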
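The batching suggestion above could be sketched roughly as follows. This is a library-agnostic sketch, not the Tag2Text implementation: `chunked` and `run_in_batches` are hypothetical helper names, and `run_batch` stands in for whatever function stacks a batch of preprocessed images into one tensor and calls the model once for the whole batch.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def chunked(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of `items`, each at most `batch_size` long."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


def run_in_batches(
    items: Sequence[T],
    batch_size: int,
    run_batch: Callable[[List[T]], List[R]],
) -> List[R]:
    """Call `run_batch` once per batch and flatten the per-item outputs.

    `run_batch` is an assumption: it must return exactly one output
    per input, in the same order.
    """
    results: List[R] = []
    for batch in chunked(items, batch_size):
        outputs = run_batch(batch)
        if len(outputs) != len(batch):
            raise ValueError("run_batch must return one output per input")
        results.extend(outputs)
    return results


if __name__ == "__main__":
    # Placeholder "model": a real run_batch would run one forward pass per batch.
    fake_run_batch = lambda batch: [f"tags for {path}" for path in batch]
    print(run_in_batches(["a.jpg", "b.jpg", "c.jpg"], 2, fake_run_batch))
```

The gain comes from calling the model once per batch instead of once per image; the best batch size depends on available memory and would need to be found by experiment.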
from recognize-anything.
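That is caused by loading the model every time; switch to loading it only once. On a Mac M2 CPU with images from the web, each image takes between 3 and 10 seconds.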
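One way to load only once is to cache the loaded model behind a function, so that repeated calls reuse the same object. The sketch below uses the standard library's `functools.lru_cache`; `_load_model`, `get_model`, and `tag_image` are hypothetical names, and the dictionary is a placeholder for the real model object.

```python
from functools import lru_cache


def _load_model(checkpoint_path: str) -> dict:
    """Hypothetical stand-in for the real, slow model construction."""
    return {"checkpoint": checkpoint_path}


@lru_cache(maxsize=1)
def get_model(checkpoint_path: str) -> dict:
    """Return the model for `checkpoint_path`, loading it on first use only."""
    return _load_model(checkpoint_path)


def tag_image(image_path: str, checkpoint_path: str) -> str:
    """Process one image; the model is fetched from the cache after the first call."""
    model = get_model(checkpoint_path)
    return f"{image_path} processed with {model['checkpoint']}"


if __name__ == "__main__":
    for path in ["1.jpg", "2.jpg", "3.jpg"]:
        print(tag_image(path, "model.pth"))
    # cache_info() shows how many calls actually triggered a load (misses).
    print(get_model.cache_info())
```

This only helps inside a single long-running process. If the script is started once per image, the model is still reloaded on every start, and a persistent worker or service would be needed instead.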
from recognize-anything.
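For Tag2Text, you can change the code to run inference in batches, which greatly improves inference efficiency. If you only need the tagging output, you can have the generate function output only the tags and stop there. In addition, our newly released Recognize Anything Model (RAM) can recognize arbitrary user-defined categories, which can further improve tagging inference efficiency. This feature will be released in the future.
Has it already been released, or will it be released tomorrow?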
from recognize-anything.
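@xinyu1205 @onexuan Thank you both for the suggestions. I wrote a test: a slow image now takes about 4 seconds to process, and a fast one takes 2 to 3 seconds.
The code is below. I am posting it so you can take a look: is there still room for improvement?
File name: test.py
Usage examples: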
- python test.py --cache-path C:/Users/gaoyo/.cache/Tag2Text --image-dir C:/Users/gaoyo/Desktop/test1/
- python test.py --cache-path C:/Users/gaoyo/.cache/Tag2Text --images C:/Users/gaoyo/Desktop/test1/20170912_234158_1_14_70wf.jpeg C:/Users/gaoyo/Desktop/test1/20170918_192435_1_57_9yto.jpeg
# -*- coding: utf-8 -*-
r'''
Author: gaoyong [email protected]
Date: 2023-06-08 10:51:43
LastEditors: gaoyong [email protected]
LastEditTime: 2023-06-08 11:07:57
FilePath: \Tag2Text\test.py
Description: Automatically generate image tags and content descriptions
'''
import argparse
import imghdr  # note: deprecated since Python 3.11 and removed in Python 3.13
import json
import os
import time

import torch
import torchvision.transforms as transforms
from PIL import Image

from models.tag2text import tag2text_caption


def parse_args():
    """
    Parse the command line arguments for Tag2Text inference.
    :return: The parsed arguments as an `argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(
        description='Tag2Text inference for tagging and captioning')
    parser.add_argument('--image-dir',
                        metavar='DIR',
                        help='path to directory containing input images',
                        default='')
    parser.add_argument('--images',
                        metavar='IMAGE-LIST',
                        nargs='+',
                        help='list of space-separated image filenames',
                        default=[])
    parser.add_argument('--pretrained',
                        metavar='DIR',
                        help='path to pretrained model',
                        default='D:/work/Tag2Text/pretrained/tag2text_swin_14m.pth')
    parser.add_argument('--image-size',
                        default=384,
                        type=int,
                        metavar='N',
                        help='input image size (default: 384)')
    parser.add_argument('--thre',
                        default=0.68,
                        type=float,
                        metavar='N',
                        help='threshold value')
    parser.add_argument('--specified-tags',
                        default='None',
                        help='User input specified tags')
    parser.add_argument('--cache-path',
                        default='None',
                        help='cache model file path')
    return parser.parse_args()


def initialize_model(cache_path, pretrained, image_size, thre):
    """
    Initialize a Tag2Text model, reusing a cached copy of the model when one exists.
    :param cache_path: Cache model file path.
    :param pretrained: Path to the pre-trained model.
    :param image_size: Input image size.
    :param thre: Threshold value for tagging.
    :return: A pre-trained Tag2Text model in evaluation mode.
    """
    # delete some tags that may disturb captioning
    # 127: "quarter"; 2961: "back", 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
    delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]

    if os.path.exists(cache_path):
        # Reuse the whole model object saved by an earlier run.
        # Note: recent PyTorch releases default torch.load to weights_only=True,
        # so loading a fully pickled model may require weights_only=False there.
        model = torch.load(cache_path)
    else:
        model = tag2text_caption(
            pretrained=pretrained,
            image_size=image_size,
            vit='swin_b',
            delete_tag_index=delete_tag_index
        )
        # Save the freshly built model so later runs can skip construction
        torch.save(model, cache_path)

    # Apply the requested threshold even when the model came from the cache
    model.threshold = thre  # threshold for tagging
    model.eval()
    return model


def generate(model, image, input_tags=None):
    """
    Generate tags and a caption for an input image.
    :param model: The neural network model used for generating captions and predicting tags for an input
    image.
    :param image: The input image to generate tags and captions for.
    :param input_tags: The input tags used as hints for the model to generate captions for the input image.
    It is an optional parameter and can be set to None or left empty if no tag hint is required.
    :return: A tuple of predicted tags, input tags, and generated captions.
    """
    if input_tags in ('', 'none', 'None'):
        input_tags = None

    # First pass: let the model predict the tags itself
    with torch.no_grad():
        caption, tag_predict = model.generate(image,
                                              tag_input=None,
                                              max_length=50,
                                              return_tag_predict=True)

    if input_tags is None:
        return tag_predict[0], None, caption[0]

    # Second pass: generate a caption guided by the user-specified tags
    input_tag_list = [input_tags.replace(',', ' | ')]
    with torch.no_grad():
        caption, input_tags = model.generate(image,
                                             tag_input=input_tag_list,
                                             max_length=50,
                                             return_tag_predict=True)
    return tag_predict[0], input_tags[0], caption[0]


def inference(images_dir, image_list, model, image_size, input_tags=None):
    """
    Generate tags and captions for every image in a directory or, if no directory is given,
    for every image in an explicit list of files.
    :param images_dir: A directory containing input images that the model will use to generate captions and
    potentially predict tags for.
    :param image_list: A list of input images the model will use to generate captions and potentially
    predict tags for.
    :param model: The neural network model used for generating captions and predicting tags for an input
    image.
    :param image_size: The size each image is resized to before it is passed to the model.
    :param input_tags: Optional tags used as hints for the model when generating captions.
    Can be set to None or left empty if no tag hint is required, defaults to None.
    :return: A list of dictionaries, each containing predicted tags, input tags (if provided), and
    generated captions for a given input image.
    """
    results = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(), normalize
    ])

    if images_dir and os.path.isdir(images_dir):
        for filename in os.listdir(images_dir):
            filepath = os.path.join(images_dir, filename)
            # Skip subdirectories and files that are not recognized as images
            if not os.path.isfile(filepath) or not imghdr.what(filepath):
                continue
            img = Image.open(filepath).convert("RGB")
            img_tensor = transform(img).unsqueeze(0).to(device)
            res = generate(model, img_tensor, input_tags)
            results.append({
                "filepath": filepath,
                "model_identified_tags": res[0],
                "user_specified_tags": res[1],
                "image_caption": res[2]
            })
    elif image_list and isinstance(image_list, list):
        for img_path in image_list:
            filepath = os.path.abspath(img_path)
            # Skip missing paths and files that are not recognized as images
            if not os.path.isfile(filepath) or not imghdr.what(filepath):
                continue
            img = Image.open(filepath).convert("RGB")
            img_tensor = transform(img).unsqueeze(0).to(device)
            res = generate(model, img_tensor, input_tags)
            results.append({
                "filepath": img_path,
                "model_identified_tags": res[0],
                "user_specified_tags": res[1],
                "image_caption": res[2]
            })
    return results


def main():
    """
    Load the Tag2Text model, run it on the images given on the command line,
    and print the tags and captions as JSON.
    """
    # Note: the timer starts before model initialization, so the reported time includes model loading
    start_time = time.time()
    args = parse_args()

    # check if a list of images is provided
    images = args.images if args.images else None

    # initialize the model
    model = initialize_model(
        args.cache_path, args.pretrained, args.image_size, args.thre)

    # perform inference on images, passing along any user-specified tags
    data = inference(args.image_dir, images, model,
                     args.image_size, input_tags=args.specified_tags)

    # output the results
    results = {
        "status": 0,
        "message": 'ok',
        "data": data
    }
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(
        f"Processed {len(results['data'])} images in {elapsed_time:.2f} seconds.")
    print(json.dumps(results, ensure_ascii=False, indent=2))


# Usage examples:
# 1. python test.py --cache-path C:/Users/gaoyo/.cache/Tag2Text --image-dir C:/Users/gaoyo/Desktop/test1/
# 2. python test.py --cache-path C:/Users/gaoyo/.cache/Tag2Text --images C:/Users/gaoyo/Desktop/test1/20170912_234158_1_14_70wf.jpeg C:/Users/gaoyo/Desktop/test1/20170918_192435_1_57_9yto.jpeg
if __name__ == '__main__':
    main()
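The timer in the script above starts before the model is initialized, so the printed time mixes model loading with per-image inference. To see where the time goes, each image could be timed separately, roughly as sketched below. `time_each` and `summarize` are hypothetical helpers, and `process` stands in for the per-image inference call.

```python
import time
from typing import Callable, Dict, Iterable, List, Tuple, TypeVar

T = TypeVar("T")


def time_each(items: Iterable[T], process: Callable[[T], object]) -> List[Tuple[T, float]]:
    """Run `process` on every item and record how long each call took, in seconds."""
    timings: List[Tuple[T, float]] = []
    for item in items:
        start = time.perf_counter()
        process(item)
        timings.append((item, time.perf_counter() - start))
    return timings


def summarize(timings: List[Tuple[T, float]]) -> Dict[str, float]:
    """Reduce per-item timings to count, total, mean, and slowest call."""
    durations = [seconds for _, seconds in timings]
    if not durations:
        return {"count": 0, "total": 0.0, "mean": 0.0, "max": 0.0}
    total = sum(durations)
    return {
        "count": len(durations),
        "total": total,
        "mean": total / len(durations),
        "max": max(durations),
    }


if __name__ == "__main__":
    # Placeholder workload; a real run would call the model on each image path.
    timings = time_each(["a.jpg", "b.jpg"], lambda path: time.sleep(0.01))
    print(summarize(timings))
```

On a GPU the first call is often slower than later ones because of one-time setup work, so it can be worth looking at the first image's time separately from the mean.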
from recognize-anything.
@xinyu1205 How much GPU memory is needed to run the RAM model?
from recognize-anything.
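Could you give me a download link for the cache-path file? I cannot generate it automatically: the code reports a network connection error, which may be caused by my network.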
from recognize-anything.
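That is caused by loading the model every time; switch to loading it only once. On a Mac M2 CPU with images from the web, each image takes between 3 and 10 seconds.
For the Recognize Anything Model (RAM), how do I change it to load only once?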
from recognize-anything.
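For Tag2Text, you can change the code to run inference in batches, which greatly improves inference efficiency. If you only need the tagging output, you can have the generate function output only the tags and stop there. In addition, our newly released Recognize Anything Model (RAM) can recognize arbitrary user-defined categories, which can further improve tagging inference efficiency. This feature will be released in the future.
RAM being able to recognize arbitrary user-defined categories: when is that planned for release?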
from recognize-anything.