
Comments (2)

aliencaocao commented on August 15, 2024

Based on my script here, it should be quite out-of-the-box to compile and run, and I get about a 4x speed-up:

import os
from contextlib import contextmanager
from functools import partial
from time import perf_counter
from typing import Optional

import requests
import torch
from PIL import Image
from tqdm import tqdm

from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache


@contextmanager
def catchtime(s):
    """Print the wall-clock time of the enclosed block; yields an elapsed-time getter."""
    start = perf_counter()
    yield lambda: perf_counter() - start
    print(f'Time of {s=}: {perf_counter() - start:.3f} seconds')

# Inductor settings (as used in gpt-fast) that help compiled decoding performance
# noinspection PyProtectedMember
torch._inductor.config.coordinate_descent_tuning = True
# noinspection PyProtectedMember
torch._inductor.config.triton.unique_kernel_names = True
# noinspection PyProtectedMember
torch._inductor.config.fx_graph_cache = True

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MODEL_NAME = 'models/llava-v1.6-mistral-7b-hf'


def mem():
    """Return (peak, current) CUDA memory allocated, in GiB."""
    gib = 1024 ** 3
    return torch.cuda.max_memory_allocated() / gib, torch.cuda.memory_allocated() / gib


assert torch.cuda.is_available()
device = "cuda"



def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L40C1-L42C82"""
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L44C1-L52C17"""
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L54C1-L57C27"""
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_tokens(model, cur_token, cache_position):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L64C1-L68C45"""
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0)[0]
    return new_token


decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)


def gen(model, inputs, iters=100):
    print(inputs['input_ids'].shape, inputs['image_sizes'])
    generated_ids = torch.zeros((1, iters), dtype=torch.int, device=device)
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False):
        # Prefill: one full forward over the prompt (text + expanded image tokens)
        output = model(**inputs)
        logits = output.logits
        # output.loss is None without labels, so take the prompt length from the
        # logits, which already cover the sequence after image-token expansion
        seq_len = logits.shape[1]
        cache_position = torch.tensor([seq_len], device=device)
        input_id = sample(logits, temperature=0)[0]
        generated_ids[:, 0] = input_id[:, 0]
    gen_pos = torch.tensor([1], device=device)
    print('post-1st  ', mem())
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        # Decode loop: one compiled forward per token, advanced via cache_position
        for i in tqdm(range(iters - 1)):
            input_id = decode_one_tokens(model.language_model, input_id.clone(), cache_position)
            generated_ids.index_copy_(1, gen_pos, input_id)
            cache_position += 1
            gen_pos += 1
    print('post-last ', mem())
    return generated_ids


with torch.inference_mode():
    processor = LlavaNextProcessor.from_pretrained(MODEL_NAME)
    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    image = Image.open(requests.get(url, stream=True).raw)
    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

    torch.cuda.memory._record_memory_history()
    print('pre-model', mem())
    model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
    print('pre-cache', mem())
    static_cache = partial(StaticCache, dtype=torch.float16)
    # note: _setup_cache is a private transformers API and may change between releases
    model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)
    print('pre-comp ', mem())
    model.language_model.compile(mode='reduce-overhead', fullgraph=True)
    model.vision_tower.compile(mode='reduce-overhead', fullgraph=True)
    print('pre-proc ', mem())
    inputs = processor(prompt, image, return_tensors="pt").to(device)

    print('pre-gen1 ', mem())
    with catchtime('first compile gen'):
        out = gen(model, inputs, iters=10)
        # print(out)
        print(processor.decode(out[0], skip_special_tokens=True))
    print('pre-gen2 ', mem())
    with catchtime('second compile gen'):
        out = gen(model, inputs, iters=100)
        # print(out)
        print(processor.decode(out[0], skip_special_tokens=True))
    # torch.cuda.memory._dump_snapshot("snapshot_full.pickle")

But the issue here is that compiling the full fp16 model requires more than 16 GB of VRAM, which is more than what I have available in production.
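If the goal is fitting inference into 16 GB rather than compiling at full precision, one option might be loading the weights in 4-bit via bitsandbytes. Below is a minimal sketch, assuming bitsandbytes is installed and reusing the same local model path as above; note that quantized layers may not be compatible with torch.compile(fullgraph=True), so this likely trades some of the compile speed-up for memory:

import torch
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration

MODEL_NAME = 'models/llava-v1.6-mistral-7b-hf'  # same local path as in the script above

# 4-bit NF4 weights with fp16 compute: weight memory drops to roughly a quarter of fp16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)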


MDK8888 commented on August 15, 2024

Hey, apologies for the late response! I will look into this and get back to you soon :)

