Comments (2)
Based on my script here, it should be fairly out-of-the-box to compile and run, and I do get about a 4x speed-up:
import os
from contextlib import contextmanager
from functools import partial
from time import perf_counter
from typing import Optional

import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache


@contextmanager
def catchtime(s):
    start = perf_counter()
    yield lambda: perf_counter() - start
    print(f'Time of {s=}: {perf_counter() - start:.3f} seconds')


# noinspection PyProtectedMember
torch._inductor.config.coordinate_descent_tuning = True
# noinspection PyProtectedMember
torch._inductor.config.triton.unique_kernel_names = True
# noinspection PyProtectedMember
torch._inductor.config.fx_graph_cache = True

os.environ["TOKENIZERS_PARALLELISM"] = "true"
MODEL_NAME = 'models/llava-v1.6-mistral-7b-hf'


def mem():
    # (peak, current) CUDA memory allocated, in GiB
    return torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, torch.cuda.memory_allocated() / 1024 / 1024 / 1024


assert torch.cuda.is_available()
device = "cuda"
def multinomial_sample_one_no_sync(probs_sort):  # does multinomial sampling without a CUDA synchronization
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L40C1-L42C82"""
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L44C1-L52C17"""
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L54C1-L57C27"""
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_tokens(model, cur_token, cache_position):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L64C1-L68C45"""
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0)[0]
    return new_token


decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
def gen(model, inputs, iters=100):
    print(inputs['input_ids'].shape, inputs['image_sizes'])
    generated_ids = torch.zeros((1, iters), dtype=torch.int, device=device)
    # Prefill: run the full prompt (text + image tokens) through the whole model once.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False):
        output = model(**inputs)
    seq_len, logits = output.loss, output.logits
    cache_position = torch.tensor([seq_len], device=device)
    input_id = sample(logits, temperature=0)[0]
    generated_ids[:, 0] = input_id[:, 0]
    gen_pos = torch.tensor([1], device=device)
    print('post-1st ', mem())
    # Decode: one token at a time through the compiled language model and the static cache.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        for i in tqdm(range(iters - 1)):
            input_id = decode_one_tokens(model.language_model, input_id.clone(), cache_position)
            generated_ids.index_copy_(1, gen_pos, input_id)
            cache_position += 1
            gen_pos += 1
    print('post-last ', mem())
    return generated_ids
with torch.inference_mode():
    processor = LlavaNextProcessor.from_pretrained(MODEL_NAME)
    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    image = Image.open(requests.get(url, stream=True).raw)
    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
    torch.cuda.memory._record_memory_history()
    print('pre-model', mem())
    model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
    print('pre-cache', mem())
    static_cache = partial(StaticCache, dtype=torch.float16)
    model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)
    print('pre-comp ', mem())
    model.language_model.compile(mode='reduce-overhead', fullgraph=True)
    model.vision_tower.compile(mode='reduce-overhead', fullgraph=True)
    print('pre-proc ', mem())
    inputs = processor(prompt, image, return_tensors="pt").to(device)
    print('pre-gen1 ', mem())
    with catchtime('first compile gen:'):
        out = gen(model, inputs, iters=10)
    # print(out)
    print(processor.decode(out[0], skip_special_tokens=True))
    print('pre-gen2 ', mem())
    with catchtime('second compile gen:'):
        out = gen(model, inputs, iters=100)
    # print(out)
    print(processor.decode(out[0], skip_special_tokens=True))
    # torch.cuda.memory._dump_snapshot("snapshot_full.pickle")
But the issue here is that compiling the full fp16 model requires more than 16 GB of VRAM, which is more than what I have available in production.
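A possible workaround (not part of the original script; a minimal sketch assuming the bitsandbytes 4-bit integration in transformers, and assuming it composes with the static-cache compile path above, which is not guaranteed) would be to quantize the weights before compiling:

import torch
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration

MODEL_NAME = 'models/llava-v1.6-mistral-7b-hf'

# Hypothetical memory-saving variant: NF4 weights shrink the ~14 GB of fp16
# parameters to roughly 4 GB, leaving headroom for the static KV cache and
# torch.compile workspaces on a 16 GB card.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    device_map="auto",
)

The rest of the script (cache setup, compilation, gen) would then be reused unchanged on top of this model.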
Hey, apologies for the late response! I will look into this and get back to you soon :)