Comments (15)
Yes, the first transcription is much slower. The second is much better than the first, but it is still not faster than WhisperTransformers.
For example, the durations are as follows:
Duration of the first transcription with WhisperJAX: 60.4s
Duration of the second transcription with WhisperJAX: 20.1s
Duration of the WhisperTransformers: 14.23s
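To put the three numbers above side by side, here is a minimal sketch that turns them into ratios. The dictionary labels and the `ratio` helper are made up for illustration; only the durations come from the comment.

```python
# Durations in seconds, copied from the comment above.
durations = {
    "WhisperJAX, first run": 60.4,
    "WhisperJAX, second run": 20.1,
    "WhisperTransformers": 14.23,
}


def ratio(a: float, b: float) -> float:
    """Return how many times longer `a` is than `b`."""
    return a / b


warm_vs_cold = ratio(durations["WhisperJAX, first run"],
                     durations["WhisperJAX, second run"])
jax_vs_hf = ratio(durations["WhisperJAX, second run"],
                  durations["WhisperTransformers"])

print(f"first JAX run takes {warm_vs_cold:.1f}x as long as the second")
print(f"second JAX run takes {jax_vs_hf:.2f}x as long as WhisperTransformers")
```

So the second run is roughly three times faster than the first, yet still about 1.4 times slower than the Transformers run on this machine.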
from whisper-jax.
@themanyone I tried your script with the whisper-large-v2 model, and the output is like below.
2023-04-25 14:58:25.859863: I external/xla/xla/service/service.cc:168] XLA service 0xa2ff910 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-04-25 14:58:25.859891: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>
2023-04-25 14:58:25.868946: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient created.
2023-04-25 14:58:26.411156: I external/xla/xla/service/service.cc:168] XLA service 0xa072e20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-04-25 14:58:26.411209: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce RTX 3090 Ti, Compute Capability 8.6
2023-04-25 14:58:26.411804: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:198] Using BFC allocator.
2023-04-25 14:58:26.411887: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 19069648896 bytes on device 0 for BFCAllocator.
2023-04-25 14:58:32.620089: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8500
2023-04-25 14:58:32.688973: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
It took forever to get flax/jax/cuda installed correctly on this ancient hardware. But FWIW it's not faster on this old brick either. It does take most load off CPU, however... Confirming with Quadro M3000M, Compute Capability 5.2
Try the TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS='' stanza and see if it tells you anything new.
$ TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS='' python
Python 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from whisper_jax import FlaxWhisperPipline
2023-04-25 03:26:34.596566: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
>>> import jax.numpy as jnp
>>>
>>> pipeline = FlaxWhisperPipline("openai/whisper-small.en", dtype=jnp.bfloat16, batch_size=16)
2023-04-25 03:26:57.709620: I external/xla/xla/service/service.cc:168] XLA service 0xad4fed0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-04-25 03:26:57.709677: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>
2023-04-25 03:26:57.713812: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient created.
2023-04-25 03:26:57.779957: I external/xla/xla/stream_executor/cuda/cuda_gpu_executor.cc:997] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-25 03:26:57.780275: I external/xla/xla/service/service.cc:168] XLA service 0xacc8c40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-04-25 03:26:57.780332: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Quadro M3000M, Compute Capability 5.2
2023-04-25 03:26:57.781066: I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:198] Using BFC allocator.
2023-04-25 03:26:57.781166: I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 3175464960 bytes on device 0 for BFCAllocator.
2023-04-25 03:26:57.807468: I external/xla/xla/pjrt/pjrt_api.cc:86] GetPjrtApi was found for tpu at /home/k/.local/lib/python3.10/site-packages/libtpu/libtpu.so
2023-04-25 03:26:57.807502: I external/xla/xla/pjrt/pjrt_api.cc:58] PJRT_Api is set for device type tpu
2023-04-25 03:26:59.436941: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8900
2023-04-25 03:26:59.627354: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
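When reading a log like the one above, the lines that matter for "is the GPU being picked up?" are the `initialized for platform` and `StreamExecutor device` lines. Below is a rough sketch for pulling those out of a saved log. The function name `xla_devices` and the regular expressions are my own assumptions based on the log format shown in this thread, not anything provided by JAX or XLA.

```python
import re
from typing import Iterable, List, Optional, Tuple

# Patterns written against the log lines quoted in this thread (assumption:
# other XLA versions may format these messages differently).
PLATFORM_RE = re.compile(r"XLA service \S+ initialized for platform (\w+)")
DEVICE_RE = re.compile(r"StreamExecutor device \((\d+)\): (.+)$")


def xla_devices(log_lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Return (platform, device description) pairs found in XLA startup logs."""
    devices: List[Tuple[str, str]] = []
    platform: Optional[str] = None
    for line in log_lines:
        match = PLATFORM_RE.search(line)
        if match:
            # Remember the platform; its device lines follow it.
            platform = match.group(1)
            continue
        match = DEVICE_RE.search(line)
        if match and platform is not None:
            devices.append((platform, match.group(2).strip()))
    return devices


sample = [
    "I external/xla/xla/service/service.cc:168] XLA service 0xad4fed0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:",
    "I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>",
    "I external/xla/xla/service/service.cc:168] XLA service 0xacc8c40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:",
    "I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Quadro M3000M, Compute Capability 5.2",
]
found = xla_devices(sample)
for platform, device in found:
    print(f"{platform}: {device}")
```

If no CUDA entry shows up in the result, the run is probably falling back to the CPU.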
I spoke too soon, perhaps. I was not able to get a speedup with dtype=jnp.bfloat16, batch_size=16.
HOWEVER
When I used the example exactly as given in the README... and without debugging... it works!!!
Decoding is sped up to 1.3 seconds the second time the pipeline is called.
JAX_PLATFORMS='' python
from whisper_jax import FlaxWhisperPipline
pipeline = FlaxWhisperPipline("openai/whisper-small.en")
import time
t = time.time(); text = pipeline("test.mp3"); print(time.time() - t)
15.293785810470581
t = time.time(); text = pipeline("test.mp3"); print(time.time() - t)
1.2636265754699707
Here I used the recommended parameters (https://huggingface.co/blog/asr-chunking https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor?usp=sharing#scrollTo=Mh_e6rV62QUM). @themanyone
In both scripts, the parameters, audio data, and environment are the same. Still, there is no speedup. @sanchit-gandhi
Hey @melihogutcen - if I understand correctly, the JAX code you're running only does one transcription step (based on what you've shared here: #44 (comment))? If this is the case, this first transcription step is our compilation step, which we expect to be slow.
If you do a second transcription step, you'll find that Whisper JAX should be extremely fast:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import time
# instantiate pipeline with float16 and enable batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float16, batch_size=8)
# transcribe and return timestamps - compilation step will be slow
start = time.time()
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
runtime = time.time() - start
print("Compilation: ", runtime)
# transcribe again - use cached function, will be fast
start = time.time()
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
runtime = time.time() - start
print("Cached: ", runtime)
You can read more about just-in-time (JIT) compilation here: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#jit-compiling-a-function
so this implementation only serves use cases where you intend to transcribe the same audio file more than once?
"so this implementation only serves use cases where you intend to transcribe the same audio file more than once?"
Why would you transcribe the same audio file twice? I can get 1000000x performance by caching the text...
No, you can transcribe any audio file with this method. You have to run it slowly once (compile). After doing this, you can run it fast on any audio file after that (cached). See the JAX JIT docs for details on JIT: https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
And the Kaggle Notebook for an application to Whisper JAX: https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu
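To make the "slow once, fast afterwards" point concrete, here is a small timing sketch that keeps the first call separate from later calls. Everything in it is hypothetical: `time_calls` is a helper written for this example and `fake_pipeline` is a stand-in that only imitates a one-off compilation cost with `time.sleep`. It is not the Whisper JAX pipeline, but the same helper could wrap a real pipeline call.

```python
import time
from typing import Any, Callable, List, Tuple


def time_calls(fn: Callable[..., Any], inputs: List[Any]) -> Tuple[float, List[float]]:
    """Time fn on inputs[0] (warm-up) separately from the remaining inputs."""
    start = time.perf_counter()
    fn(inputs[0])
    first = time.perf_counter() - start

    later = []
    for item in inputs[1:]:
        start = time.perf_counter()
        fn(item)
        later.append(time.perf_counter() - start)
    return first, later


# Stand-in for a JIT-compiled function: the expensive step happens once per
# process, not once per input file (hypothetical, for illustration only).
_compiled = False


def fake_pipeline(path: str) -> str:
    global _compiled
    if not _compiled:
        time.sleep(0.2)  # pretend this is the compilation step
        _compiled = True
    time.sleep(0.01)  # pretend this is the cached, fast path
    return f"transcript of {path}"


first, later = time_calls(fake_pipeline, ["a.mp3", "b.mp3", "c.mp3"])
print(f"first call: {first:.3f}s")
print(f"later calls on other files: {[round(t, 3) for t in later]}")
```

The later calls use different file names on purpose: the one-off cost is tied to the compiled function, not to a particular audio file.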
Is it possible to pre-compile rather than passing through an audio file first?
Yep - you can just pass a dummy log-mel spectrogram, see:
Lines 86 to 87 in e2dddc9
After that you can call:
outputs = pipeline("audio.mp3", return_timestamps=True)
@melihogutcen I have the same issue. I tried the original implementation and this one, and the original seems to be faster each time (I am comparing the second runs). Were you able to get the fast results?
This issue WAS happening before I completely wiped the system and did a fresh install on a new hard drive. This time I installed updated video drivers, CUDA, and cuDNN from the NVIDIA website instead of the distro-packaged versions. Now it's super fast.
Hello guys, I'm in the same boat. Here is my code:
import time
from whisper_jax import FlaxWhisperPipline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
t = time.time(); text = pipeline("rec11.mp3"); print(time.time() - t)
t = time.time(); text = pipeline("rec11.mp3"); print(time.time() - t)
first run is: 62 sec
last run is: 38 sec
audio file duration: 00:02:22.87 (from ffmpeg)
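For comparing these timings across setups, it can help to express them as a real-time factor (processing time divided by audio length). A quick sketch using the numbers above; `parse_ffmpeg_duration` is a helper name invented for this example.

```python
def parse_ffmpeg_duration(text: str) -> float:
    """Convert an HH:MM:SS.ss duration string into seconds."""
    hours, minutes, seconds = text.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


# Values reported in this comment.
audio_seconds = parse_ffmpeg_duration("00:02:22.87")
timings = {"first run": 62.0, "last run": 38.0}

for label, elapsed in timings.items():
    rtf = elapsed / audio_seconds  # below 1.0 means faster than real time
    print(f"{label}: RTF = {rtf:.2f} ({1 / rtf:.1f}x real time)")
```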
output of my nvidia-smi:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3090 On | 00000000:01:00.0 Off | N/A |
| 0% 45C P8 25W / 350W| 19175MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
/usr/local/cuda/bin/nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
I was running Python with TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS='' and got these messages:
I external/xla/xla/service/service.cc:168] XLA service 0x560a412f7350 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>
I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:433] TfrtCpuClient created.
I external/xla/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I external/xla/xla/service/service.cc:168] XLA service 0x560a3fe5d650 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I external/xla/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
I external/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc:521] Using BFC allocator.
I external/xla/xla/pjrt/gpu/gpu_helpers.cc:105] XLA backend allocating 19065569280 bytes on device 0 for BFCAllocator.
I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:434] Loaded cuDNN version 8700
I external/tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
I external/tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:52] Using nvlink for parallel linking
I external/xla/xla/stream_executor/gpu/asm_compiler.cc:328] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_1494', 68 bytes spill stores, 68 bytes spill loads
I external/xla/xla/stream_executor/gpu/asm_compiler.cc:328] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_1494', 128 bytes spill stores, 128 bytes spill loads
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 45: 346.347 vs 387.222
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 70: 350.232 vs 389.543
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 299: 350.822 vs 394.675
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 407: 345.434 vs 394.303
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 533: 352.764 vs 396.02
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 553: 347.322 vs 391.286
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 593: 348.782 vs 393.536
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 694: 343.945 vs 390.099
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 721: 355.291 vs 395.885
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 797: 347.525 vs 390.146
E external/xla/xla/service/gpu/triton_autotuner.cc:377] Results mismatch between different tilings. This is likely a bug/unexpected loss of precision.
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 45: 346.347 vs 387.222
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 70: 350.232 vs 389.543
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 299: 350.822 vs 394.675
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 407: 345.434 vs 394.303
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 533: 352.764 vs 396.02
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 553: 347.322 vs 391.286
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 593: 348.782 vs 393.536
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 694: 343.945 vs 390.099
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 721: 355.291 vs 395.885
E external/xla/xla/service/gpu/buffer_comparator.cc:731] Difference at 797: 347.525 vs 390.146
E external/xla/xla/service/gpu/triton_autotuner.cc:377] Results mismatch between different tilings. This is likely a bug/unexpected loss of precision.
I would appreciate any pointers on how to find the bottleneck. faster-whisper gets this job done much faster, so I'm guessing something is not okay with my setup.
Thank you.
I could not install torch in the same venv as whisper-jax when making my open-source Whisper Dictation app here. Doing so would downgrade nvidia-cudnn-cu11 to a non-working version that would mostly use the CPU. Then I'd have to run pip install --upgrade nvidia-cudnn-cu11 to get it back. So I put it in a venv to keep them separate. I have torch, CUDA, and Python 3.11 running in my main Python install, and nvidia-cudnn-cu11, jax[cuda], Python 3.10, et al. in the virtual environment.
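A quick way to see whether installing one package silently changed another is to print the installed versions before and after, in each environment. A minimal sketch using only the standard library; the package list is just the set of names mentioned in this comment, and `installed_versions` is a helper written for illustration.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names taken from the comment above; adjust for your own setup.
PACKAGES = ("jax", "jaxlib", "torch", "nvidia-cudnn-cu11")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```

Running this inside the venv and again in the main install shows which environment holds which versions.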