Comments (15)

melihogutcen commented on May 14, 2024

Yes, the first transcription is much slower. The second one is much faster than the first, but it is still not faster than WhisperTransformers.

For example, the durations were:

Duration of the first transcription with WhisperJAX: 60.4s
Duration of the second transcription with WhisperJAX: 20.1s

Duration of the WhisperTransformers: 14.23s

from whisper-jax.

melihogutcen commented on May 14, 2024

@themanyone I tried your script with the whisper-large-v2 model, and the output is as follows.

2023-04-25 14:58:25.859863: I external/xla/xla/service/service.cc:168] XLA service 0xa2ff910 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-04-25 14:58:25.859891: I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): Interpreter, <undefined>
2023-04-25 14:58:25.868946: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient created.
2023-04-25 14:58:26.411156: I external/xla/xla/service/service.cc:168] XLA service 0xa072e20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-04-25 14:58:26.411209: I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3090 Ti, Compute Capability 8.6
2023-04-25 14:58:26.411804: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:198] Using BFC allocator.
2023-04-25 14:58:26.411887: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 19069648896 bytes on device 0 for BFCAllocator.
2023-04-25 14:58:32.620089: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8500
2023-04-25 14:58:32.688973: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking

themanyone commented on May 14, 2024

It took forever to get flax/jax/cuda installed correctly on this ancient hardware. But FWIW, it's not faster on this old brick either, although it does take most of the load off the CPU. Confirming with a Quadro M3000M, Compute Capability 5.2.

Try the TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS='' stanza and see if it tells you anything new.

$ TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS='' python
Python 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from whisper_jax import FlaxWhisperPipline
2023-04-25 03:26:34.596566: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
>>> import jax.numpy as jnp
>>> 
>>> pipeline = FlaxWhisperPipline("openai/whisper-small.en", dtype=jnp.bfloat16, batch_size=16)
2023-04-25 03:26:57.709620: I external/xla/xla/service/service.cc:168] XLA service 0xad4fed0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-04-25 03:26:57.709677: I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): Interpreter, <undefined>
2023-04-25 03:26:57.713812: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient created.
2023-04-25 03:26:57.779957: I external/xla/xla/stream_executor/cuda/cuda_gpu_executor.cc:997] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-25 03:26:57.780275: I external/xla/xla/service/service.cc:168] XLA service 0xacc8c40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-04-25 03:26:57.780332: I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): Quadro M3000M, Compute Capability 5.2
2023-04-25 03:26:57.781066: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:198] Using BFC allocator.
2023-04-25 03:26:57.781166: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 3175464960 bytes on device 0 for BFCAllocator.
2023-04-25 03:26:57.807468: I external/xla/xla/pjrt/pjrt_api.cc:86] GetPjrtApi was found for tpu at /home/k/.local/lib/python3.10/site-packages/libtpu/libtpu.so
2023-04-25 03:26:57.807502: I external/xla/xla/pjrt/pjrt_api.cc:58] PJRT_Api is set for device type tpu
2023-04-25 03:26:59.436941: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8900
2023-04-25 03:26:59.627354: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
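
When comparing setups, it can help to pull the device name and cuDNN version out of logs like the one above rather than reading them by eye. The following is a minimal sketch using only the standard library; the function name `summarize_xla_log` is made up for illustration, the regular expressions are fitted to the lines quoted in this thread, and the version decoding assumes the cuDNN 8.x integer encoding (major*1000 + minor*100 + patch).

```python
import re

# Patterns fitted to two kinds of lines seen in the XLA logs in this thread.
DEVICE_RE = re.compile(r"StreamExecutor device \((\d+)\): (.+)$")
CUDNN_RE = re.compile(r"Loaded cuDNN version (\d+)")


def summarize_xla_log(lines):
    """Return (devices, cudnn_version) parsed from XLA INFO log lines."""
    devices = []
    cudnn = None
    for line in lines:
        m = DEVICE_RE.search(line)
        if m:
            devices.append(m.group(2).strip())
        m = CUDNN_RE.search(line)
        if m:
            raw = int(m.group(1))
            # Assumption: cuDNN 8.x encoding, major*1000 + minor*100 + patch.
            cudnn = (raw // 1000, (raw % 1000) // 100, raw % 100)
    return devices, cudnn


# Lines abbreviated from the log above (timestamps and paths shortened).
sample = [
    "... service.cc:176]   StreamExecutor device (0): Interpreter, <undefined>",
    "... service.cc:176]   StreamExecutor device (0): Quadro M3000M, Compute Capability 5.2",
    "... cuda_dnn.cc:424] Loaded cuDNN version 8900",
]
devices, cudnn = summarize_xla_log(sample)
print(devices)
print(cudnn)
```

If no CUDA device shows up in the parsed list, JAX is most likely falling back to the CPU, which would explain slow transcription regardless of compilation.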

themanyone commented on May 14, 2024

I spoke too soon perhaps. I was not able to get speedup with dtype=jnp.bfloat16, batch_size=16
HOWEVER
When I used the example exactly as given in the README... and without debugging... it works!!!
Decoding is sped up to 1.3 seconds the 2nd time pipeline is called.

JAX_PLATFORMS='' python
from whisper_jax import FlaxWhisperPipline
pipeline = FlaxWhisperPipline("openai/whisper-small.en")
import time

t = time.time(); text = pipeline("test.mp3"); print(time.time() - t)
15.293785810470581

t = time.time(); text = pipeline("test.mp3"); print(time.time() - t)
1.2636265754699707

melihogutcen commented on May 14, 2024

Here, I used the recommended parameters (https://huggingface.co/blog/asr-chunking https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor?usp=sharing#scrollTo=Mh_e6rV62QUM) @themanyone

In both scripts, the parameters, audio data, and environment are the same. Still, there is no speedup. @sanchit-gandhi

sanchit-gandhi commented on May 14, 2024

Hey @melihogutcen - if I understand correctly, the JAX code you're running only does one transcription step (based on what you've shared here: #44 (comment))? If this is the case, this first transcription step is our compilation step, which we expect to be slow.

If you do a second transcription step, you'll find that Whisper JAX should be extremely fast:

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import time

# instantiate pipeline with float16 and enable batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float16, batch_size=8)

# transcribe and return timestamps - compilation step will be slow
start = time.time()
outputs = pipeline("audio.mp3",  task="transcribe", return_timestamps=True)
runtime = time.time() - start
print("Compilation: ", runtime)

# transcribe again - use cached function, will be fast
start = time.time()
outputs = pipeline("audio.mp3",  task="transcribe", return_timestamps=True)
runtime = time.time() - start
print("Cached: ", runtime)

You can read more about just-in-time (JIT) compilation here: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#jit-compiling-a-function
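
For benchmarking more than two runs, the same compile-then-cached pattern can be wrapped in a small helper. This is only a sketch: `time_calls` and `FakePipeline` are hypothetical names, and the fake pipeline merely simulates a slow first call so the snippet runs without a GPU or whisper-jax installed. In practice you would pass the real pipeline object instead.

```python
import time


def time_calls(fn, *args, repeats=3, **kwargs):
    """Time the first call separately from the following `repeats` calls."""
    start = time.perf_counter()
    fn(*args, **kwargs)
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        later.append(time.perf_counter() - start)
    return first, later


class FakePipeline:
    """Stand-in for a real pipeline: slow on its first call, fast afterwards."""

    def __init__(self):
        self.compiled = False

    def __call__(self, path):
        if not self.compiled:
            time.sleep(0.05)  # pretend to compile
            self.compiled = True
        return "text for " + path


first, later = time_calls(FakePipeline(), "audio.mp3", repeats=3)
print(f"first call: {first:.3f}s, best later call: {min(later):.3f}s")
```

When comparing against another implementation, the fair number to compare is the best (or mean) of the later calls, not the first one.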

erickalfaro commented on May 14, 2024

so this implementation only serves use cases where you intend to transcribe the same audio file more than once?

gptlang commented on May 14, 2024

> so this implementation only serves use cases where you intend to transcribe the same audio file more than once?

Why would you transcribe the same audio file twice? I can get 1000000x performance by caching the text...

sanchit-gandhi commented on May 14, 2024

No, you can transcribe any audio file with this method. You have to run it slowly once (compile). After doing this, you can run it fast on any audio file after that (cached). See the JAX JIT docs for details on JIT: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

And the Kaggle Notebook for an application to Whisper JAX: https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu
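
To see why the cached function works for a different audio file, note that `jax.jit` keys its cache on the shapes and dtypes of the inputs, not on their values. The toy below imitates that behaviour in plain Python so it runs anywhere; `toy_jit` is an illustrative stand-in, not how JAX is actually implemented, and the "compilation" is just a sleep.

```python
import time


def toy_jit(fn):
    """Toy stand-in for a JIT: 'compile' once per input length, then reuse."""
    cache = {}

    def wrapper(samples):
        key = len(samples)  # a real JIT keys on shapes and dtypes, not values
        if key not in cache:
            time.sleep(0.05)  # pretend compilation is expensive
            cache[key] = fn
        return cache[key](samples)

    wrapper.cache = cache
    return wrapper


@toy_jit
def mean_energy(samples):
    return sum(x * x for x in samples) / len(samples)


clip_a = [0.1] * 3000  # two different "audio clips" with the same shape
clip_b = [0.2] * 3000

for name, clip in [("clip_a", clip_a), ("clip_b", clip_b)]:
    start = time.perf_counter()
    mean_energy(clip)
    print(name, f"{time.perf_counter() - start:.3f}s")
```

The second clip has different contents but the same shape, so it reuses the cached entry. An input with a new shape would trigger another compilation.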

gptlang commented on May 14, 2024

Is it possible to pre-compile rather than passing through an audio file first?

sanchit-gandhi commented on May 14, 2024

Yep - you can just pass a dummy log-mel spectrogram, see:

whisper-jax/app/app.py

Lines 86 to 87 in e2dddc9

random_inputs = {"input_features": np.ones((BATCH_SIZE, 80, 3000))}
random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)

After that you can call:

outputs = pipeline("audio.mp3", return_timestamps=True)

HarutyunyanLiana commented on May 14, 2024

@melihogutcen I have the same issue. I tried both the original implementation and this one, and the original seems to be faster every time (I am comparing the second runs). Were you able to get the fast results?

themanyone commented on May 14, 2024

This issue WAS happening before I wiped the system and did a fresh install on a new hard drive. This time I installed updated video drivers, CUDA, and cuDNN from the NVIDIA website instead of the distro-packaged versions. Now it's super-fast.

klvnptr commented on May 14, 2024

Hello everyone, I'm in the same boat. Here is my code:

import time
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2")

t = time.time(); text = pipeline("rec11.mp3"); print(time.time() - t)
t = time.time(); text = pipeline("rec11.mp3"); print(time.time() - t)

first run is: 62 sec
last run is: 38 sec

audio file duration: 00:02:22.87 (from ffmpeg)
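
One way to put these numbers on a common scale is the real-time factor (processing time divided by audio length). The sketch below uses the figures reported just above; the helper names are made up for illustration.

```python
def parse_duration(hms):
    """Convert an ffmpeg-style 'HH:MM:SS.ss' string to seconds."""
    hours, minutes, seconds = hms.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


def real_time_factor(processing_seconds, audio_seconds):
    """Processing time divided by audio length; below 1.0 is faster than real time."""
    return processing_seconds / audio_seconds


audio = parse_duration("00:02:22.87")
for label, runtime in [("first run", 62.0), ("second run", 38.0)]:
    print(f"{label}: RTF = {real_time_factor(runtime, audio):.2f}")
```

Both runs are faster than real time, but the second run is far from the second-run figures reported earlier in the thread for smaller models.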

output of my nvidia-smi:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3090         On | 00000000:01:00.0 Off |                  N/A |
|  0%   45C    P8               25W / 350W|  19175MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

/usr/local/cuda/bin/nvcc -V

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

I was running Python with TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=''

and got these messages:

I external/xla/xla/service/service.cc:168] XLA service 0x560a412f7350 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): Interpreter, <undefined>
I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:433] TfrtCpuClient created.
I external/xla/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I external/xla/xla/service/service.cc:168] XLA service 0x560a3fe5d650 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I external/xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:521] Using BFC allocator.
I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 19065569280 bytes on device 0 for BFCAllocator.
I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:434] Loaded cuDNN version 8700
I external/tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
I external/tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking

I external/xla/xla/stream_executor/gpu/asm_compiler.cc:328] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_1494', 68 bytes spill stores, 68 bytes spill loads

I external/xla/xla/stream_executor/gpu/asm_compiler.cc:328] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_1494', 128 bytes spill stores, 128 bytes spill loads

E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 45: 346.347 vs 387.222
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 70: 350.232 vs 389.543
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 299: 350.822 vs 394.675
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 407: 345.434 vs 394.303
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 533: 352.764 vs 396.02
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 553: 347.322 vs 391.286
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 593: 348.782 vs 393.536
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 694: 343.945 vs 390.099
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 721: 355.291 vs 395.885
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 797: 347.525 vs 390.146
E external/xla/xla/service/gpu/triton_autotuner.cc:377] Results mismatch between different tilings. This is likely a bug/unexpected loss of precision.
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 45: 346.347 vs 387.222
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 70: 350.232 vs 389.543
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 299: 350.822 vs 394.675
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 407: 345.434 vs 394.303
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 533: 352.764 vs 396.02
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 553: 347.322 vs 391.286
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 593: 348.782 vs 393.536
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 694: 343.945 vs 390.099
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 721: 355.291 vs 395.885
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 797: 347.525 vs 390.146
E external/xla/xla/service/gpu/triton_autotuner.cc:377] Results mismatch between different tilings. This is likely a bug/unexpected loss of precision.
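
To judge how large these mismatches are, the `buffer_comparator` lines can be turned into relative differences. This is a rough sketch: the regular expression is fitted to the lines quoted above (plain decimal values only), and the tolerance XLA itself applies is not shown in this thread, so the output is for orientation rather than a pass/fail check.

```python
import re

# Fitted to lines such as "... Difference at 45: 346.347 vs 387.222".
DIFF_RE = re.compile(r"Difference at (\d+): (-?[\d.]+) vs (-?[\d.]+)")


def relative_differences(lines):
    """Yield (element index, relative difference) for each matching log line."""
    for line in lines:
        m = DIFF_RE.search(line)
        if m:
            a, b = float(m.group(2)), float(m.group(3))
            yield int(m.group(1)), abs(a - b) / max(abs(a), abs(b))


# Two lines abbreviated from the log above.
sample = [
    "E ... buffer_comparator.cc:731] Difference at 45: 346.347 vs 387.222",
    "E ... buffer_comparator.cc:731] Difference at 70: 350.232 vs 389.543",
]
for index, rel in relative_differences(sample):
    print(f"element {index}: {rel:.1%} apart")
```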

I would appreciate any pointers on how to find the bottleneck. faster-whisper gets this job done much faster, so I'm guessing something is not right with my setup.

Thank you.

themanyone commented on May 14, 2024

I could not install torch in the same venv as whisper-jax when making my open-source Whisper Dictation app. Doing so would downgrade nvidia-cudnn-cu11 to a non-working version that would mostly use the CPU, and I'd have to run pip install --upgrade nvidia-cudnn-cu11 to get it back. So I put whisper-jax in its own venv to keep them separate: torch with CUDA runs on Python 3.11 in my main Python install, and nvidia-cudnn-cu11, jax[cuda], et al. run on Python 3.10 in the virtual environment.
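
A quick way to check whether an environment has ended up with a conflicting mix of packages is to list the installed versions from inside it. The sketch below uses only the standard library; the package names come from the comment above, except `jaxlib`, which is added here as an assumption.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Adjust this list for your own setup.
for name, version in installed_versions(
    ["torch", "jax", "jaxlib", "nvidia-cudnn-cu11"]
).items():
    print(f"{name}: {version or 'not installed'}")
```

Running it before and after installing torch would show whether nvidia-cudnn-cu11 was downgraded.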
