sanchit-gandhi / whisper-jax
JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
License: Apache License 2.0
Thanks for your nice project. OpenAI's Whisper does not open-source the training code; could this project implement it? When I use the large-v2 model, it often outputs YouTube video advertisements, so it seems to be a problem with the training data. I want to train a model with clean data. The problem is discussed below:
openai/whisper#928
Windows 10
Miniconda3
Python3.9
jaxlib-0.3.25
jax-0.3.25
numpy-1.20.3
When I try to import using: from whisper_jax import FlaxWhisperPipline
I get this error. I am new to JAX, so any help is welcome:
cannot import name 'dot_product_attention_weights' from 'flax.linen.attention'
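A quick way to narrow this down (a minimal diagnostic sketch; the exact minimum Flax version is an assumption on my part, but dot_product_attention_weights only exists in newer flax.linen releases, so an old Flax pinned against jax/jaxlib 0.3.25 is the likely culprit):

import flax, jax

print(jax.__version__, flax.__version__)
# Raises ImportError on Flax versions that predate this helper:
from flax.linen.attention import dot_product_attention_weights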
What is the estimated first JIT compile time on a Colab Premium GPU (A100)? I'm talking about the code right below this line:
# JIT compile the forward call - slow, but we only do once
Is there anyone who uses whisper-jax to extract logits from audio?
I cannot seem to get rid of this on Google Colab:
AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
Is there any way to use the command line for the project? I'm looking for an example, thanks.
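The repo itself doesn't document a CLI as far as this thread shows, but a minimal wrapper sketch around the pipeline API used elsewhere in this thread could look like this (the argument names are hypothetical):

import argparse
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

parser = argparse.ArgumentParser(description="Transcribe an audio file with whisper-jax")
parser.add_argument("audio", help="path to the audio file")
parser.add_argument("--checkpoint", default="openai/whisper-large-v2", help="HF checkpoint or local folder")
parser.add_argument("--task", default="transcribe", choices=["transcribe", "translate"])
args = parser.parse_args()

pipeline = FlaxWhisperPipline(args.checkpoint, dtype=jnp.bfloat16)
print(pipeline(args.audio, task=args.task)["text"])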
While operating on Kaggle, I encounter this error:
ffmpeg was not found but is required to load audio files from filename.
Code:

import os
import docx
from tqdm import tqdm

def process_doc(file):
    wav_path = os.path.join("/kaggle/input/upwork-calls/CSG_CALLS", f"{file}")
    doc_path = os.path.join("/kaggle/working/docs", f"doc_{file}")
    if not os.path.exists(doc_path):
        os.mkdir(doc_path)
    for files in tqdm(os.listdir(wav_path)):
        filename = files.split(".")[0] + ".docx"
        # transcribe one call and write the text into a .docx document
        result = pipeline(os.path.join(wav_path, files), task="transcribe")
        mydoc = docx.Document()
        mydoc.add_paragraph(result['text'])
        mydoc.save(os.path.join(doc_path, filename))
    print("------------------- Saved in path ---------------- : ", doc_path)
I tried to install ffmpeg using:
!apt-get install -y ffmpeg > /dev/null
but it failed with the error: E: Package 'ffmpeg' has no installation candidate
Can anyone please help me with this issue?
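One common cause (an assumption on my part, not verified on this exact Kaggle image) is a stale apt package index in the container, so refreshing it before installing may help:

!apt-get update > /dev/null
!apt-get install -y ffmpeg > /dev/null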
When a NumPy array is passed in, the model runs fine, but this causes the model to perform poorly because the audio array is not resampled to the appropriate sample rate. This is fixed by passing a dict with array and sampling_rate keys.
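A sketch of the dict input described above (the checkpoint, the silent placeholder audio, and the 44.1 kHz rate are illustrative assumptions):

import numpy as np
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
audio = np.zeros(44100, dtype=np.float32)  # placeholder: 1 s of silence recorded at 44.1 kHz
# The dict form tells the pipeline the source rate so it can resample internally:
outputs = pipeline({"array": audio, "sampling_rate": 44100})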
I have just started working with this awesome repository. One way to improve the user experience would be to create a requirements.txt file to install the required frameworks for this repository to work. The three frameworks that need to be installed are gradio, pytube and transformers.
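A minimal sketch of such a requirements.txt (unpinned; exact version constraints would need testing):

gradio
pytube
transformers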
I tried to get transcriptions for a video from David Silver's reinforcement learning playlist on YouTube.
The model was able to generate very good transcriptions at some timestamps, but at many timestamps it generates transcriptions in a language other than English. I haven't changed any settings; I just copy-pasted the URL of the video and clicked on transcribe. The result was out in 23.4 seconds but wasn't accurate.
For more information, please have a look at the image I'm attaching below:
In the image, you can clearly observe that the model is generating transcriptions in another language, even though English is asked for. Some of it was in English, and the rest in some other language.
Hi,
I couldn't get faster results; the Transformers Whisper pipeline is faster than the JAX implementation for me.
jax ==0.4.8
jaxlib==0.4.7+cuda11.cudnn82
transformers==4.28.1
CUDA Version: 11.7
Python 3.9.16
GPU: RTX 3090 Ti
Transformers Implementation:

from transformers import pipeline

MODEL_NAME = "openai/whisper-large-v2"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    device='cuda:0',
    generate_kwargs={"language": "<|tr|>", "task": "transcribe"},
)
# sound_16k: the 16 kHz audio ("16k_sound" in the original is not a valid Python identifier)
text = pipe(sound_16k,
            return_timestamps=True,
            chunk_length_s=30.0,
            stride_length_s=[6, 0],
            batch_size=8,
            generate_kwargs={"language": "<|tr|>", "task": "transcribe"})
JAX Implementation:

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

MODEL_NAME = "openai/whisper-large-v2"
pipeline = FlaxWhisperPipline(MODEL_NAME, dtype=jnp.float16)
text = pipeline(sound_16k,
                return_timestamps=True,
                chunk_length_s=30.0,
                stride_length_s=[6, 0],
                batch_size=8,
                generate_kwargs={"language": "<|tr|>", "task": "transcribe"})
I tried 3-4 times here, but I couldn't decrease the computation time.
What is the Hugging Face model? Not the Space, the model.
Hi, I appreciate the sharing of this framework; it looks very useful.
I'm wondering if it's possible to do real-time transcription using
from transformers.pipelines.audio_utils import ffmpeg_microphone_live
as detailed in this PR:
I have an RTX 4090, and running

import json
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

def transcribe_70():
    # instantiate pipeline
    pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16, dtype=jnp.bfloat16)
    outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
    # use the cached function thereafter - super fast!!
    with open("output70.json", "w") as f:
        f.write(json.dumps(outputs))

if __name__ == '__main__':
    transcribe_70()
gives me:
2023-04-21 08:22:07.870777: I external/xla/xla/service/service.cc:168] XLA service 0x56074ee3c980 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2023-04-21 08:22:07.870792: I external/xla/xla/service/service.cc:176] StreamExecutor device (0): Interpreter, <undefined>
2023-04-21 08:22:07.873147: I external/xla/xla/pjrt/tfrt_cpu_pjrt_client.cc:218] TfrtCpuClient created.
2023-04-21 08:22:07.873292: I external/xla/xla/stream_executor/tpu/tpu_initializer_helper.cc:269] Libtpu path is: libtpu.so
2023-04-21 08:22:07.873367: I external/xla/xla/stream_executor/tpu/tpu_initializer_helper.cc:277] Failed to open libtpu: libtpu.so: cannot open shared object file: No such file or directory
2023-04-21 08:22:07.873389: I external/xla/xla/stream_executor/tpu/tpu_platform_interface.cc:73] No TPU platform found.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Using whisper in the same virtual env works with the GPU.
Is there a recommended method to implement speaker diarization with this whisper solution?
I executed the command python app.py and provided a YouTube video link through the web interface, but received the following error message:
Traceback (most recent call last):
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\transformers\pipelines\audio_utils.py", line 34, in ffmpeg_read
with subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as ffmpeg_process:
File "C:\Users\rosha\AppData\Local\Programs\Python\Python310\lib\subprocess.py", line 969, in __init__
self._execute_child(args, executable, preexec_fn, close_fds,
File "C:\Users\rosha\AppData\Local\Programs\Python\Python310\lib\subprocess.py", line 1438, in _execute_child
hp, ht, pid, tid = _winapi.CreateProcess(executable, args,
FileNotFoundError: [WinError 2] The system cannot find the file specified
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\routes.py", line 401, in run_predict
output = await app.get_blocks().process_api(
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\blocks.py", line 1302, in process_api
result = await self.call_function(
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\blocks.py", line 1025, in call_function
prediction = await anyio.to_thread.run_sync(
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\anyio\to_thread.py", line 31, in run_sync
return await get_asynclib().run_sync_in_worker_thread(
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\anyio\_backends\_asyncio.py", line 937, in run_sync_in_worker_thread
return await future
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\anyio\_backends\_asyncio.py", line 867, in run
result = context.run(func, *args)
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\app\app.py", line 183, in transcribe_youtube
inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\transformers\pipelines\audio_utils.py", line 37, in ffmpeg_read
raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error
ValueError: ffmpeg was not found but is required to load audio files from filename
I have added ffmpeg to the PATH and have also installed ffmpeg-python, but still the same issue.
In case I select the Microphone tab, record the audio, and click submit, I get the following error:
C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\pydub\utils.py:198: RuntimeWarning: Couldn't find ffprobe or avprobe - defaulting to ffprobe, but may not work
warn("Couldn't find ffprobe or avprobe - defaulting to ffprobe, but may not work", RuntimeWarning)
Traceback (most recent call last):
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\processing_utils.py", line 138, in audio_from_file
audio = AudioSegment.from_file(filename)
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\pydub\audio_segment.py", line 728, in from_file
info = mediainfo_json(orig_file, read_ahead_limit=read_ahead_limit)
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\pydub\utils.py", line 274, in mediainfo_json
res = Popen(command, stdin=stdin_parameter, stdout=PIPE, stderr=PIPE)
File "C:\Users\rosha\AppData\Local\Programs\Python\Python310\lib\subprocess.py", line 969, in __init__
self._execute_child(args, executable, preexec_fn, close_fds,
File "C:\Users\rosha\AppData\Local\Programs\Python\Python310\lib\subprocess.py", line 1438, in _execute_child
hp, ht, pid, tid = _winapi.CreateProcess(executable, args,
FileNotFoundError: [WinError 2] The system cannot find the file specified
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\routes.py", line 401, in run_predict
output = await app.get_blocks().process_api(
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\blocks.py", line 1300, in process_api
inputs = self.preprocess_data(fn_index, inputs, state)
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\blocks.py", line 1148, in preprocess_data
processed_input.append(block.preprocess(inputs[i]))
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\components.py", line 2425, in preprocess
sample_rate, data = processing_utils.audio_from_file(
File "C:\Users\rosha\Downloads\Compressed\whisper-jax-main\whisper\lib\site-packages\gradio\processing_utils.py", line 148, in audio_from_file
raise RuntimeError(msg) from e
RuntimeError: Cannot load audio from file: `ffprobe` not found. Please install `ffmpeg` in your system to use non-WAV audio file formats and make sure `ffprobe` is in your PATH.
Hi, thanks for the JAX code. Are there any plans for distilling the existing/original model?
Hello @sanchit-gandhi, thanks for sharing this repo.
I installed all the dependencies, then in terminal 1 I ran: bash launch_app.sh
In terminal 2 I ran: API_URL=http://0.0.0.0:8000/generate/ API_URL_FROM_FEATURES=http://0.0.0.0:8000/gnerate_from_features/ python app.py
When I submit a YouTube URL, I get this error:
File "/home/ubuntu/whisper-jax/app/app.py", line 72, in forward
outputs["tokens"] = np.asarray(outputs["tokens"])
KeyError: 'tokens'
Complete error
Traceback (most recent call last):
File "/home/ubuntu/anaconda3/envs/whisper/lib/python3.9/site-packages/gradio/routes.py", line 401, in run_predict
output = await app.get_blocks().process_api(
File "/home/ubuntu/anaconda3/envs/whisper/lib/python3.9/site-packages/gradio/blocks.py", line 1302, in process_api
result = await self.call_function(
File "/home/ubuntu/anaconda3/envs/whisper/lib/python3.9/site-packages/gradio/blocks.py", line 1025, in call_function
prediction = await anyio.to_thread.run_sync(
File "/home/ubuntu/anaconda3/envs/whisper/lib/python3.9/site-packages/anyio/to_thread.py", line 31, in run_sync
return await get_asynclib().run_sync_in_worker_thread(
File "/home/ubuntu/anaconda3/envs/whisper/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 937, in run_sync_in_worker_thread
return await future
File "/home/ubuntu/anaconda3/envs/whisper/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 867, in run
result = context.run(func, *args)
File "/home/ubuntu/whisper-jax/app/app.py", line 185, in transcribe_youtube
text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
File "/home/ubuntu/whisper-jax/app/app.py", line 126, in tqdm_generate
model_outputs.append(forward(batch, task=task, return_timestamps=return_timestamps))
File "/home/ubuntu/whisper-jax/app/app.py", line 72, in forward
outputs["tokens"] = np.asarray(outputs["tokens"])
KeyError: 'tokens'
AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
Which versions of jax and jaxlib work with this?
Hi,
I'm glad to have discovered this project, and after hearing how much the speed can be increased, I can't wait to give it a try.
Can JAX only be installed on Linux?
Please forgive my poor English; the above is translated.
Trying to load medium or large model, I get out of memory errors. Loading small with float16 precision works but takes all my 24 GB VRAM. Is there any way to limit Jax memory usage? The OpenAI model is far more modest in its requirements. Reducing the model weights to float16 should be a good idea too.
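For what it's worth, JAX preallocates most of the GPU memory by default; the standard XLA client environment variables (plain JAX settings, not whisper-jax specific) can rein this in. A minimal sketch:

import os
# Must be set before JAX initializes its backend:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead of most of VRAM upfront
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"   # or cap preallocation at 50% of VRAM
import jax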
The original Whisper model can take an initial_prompt value to improve the accuracy of the transcript. Is this possible in this improved version of Whisper? It really helps a lot for context words.
Let's say I want to translate English into German, etc.
How can I do this for both translation and transcription?
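A sketch of the two built-in tasks, using the pipeline call signature seen elsewhere in this thread; note that Whisper's native translate task only translates into English, so English-to-German would need a separate translation step:

import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
transcript = pipeline("audio.mp3", task="transcribe")   # text in the spoken language
translation = pipeline("audio.mp3", task="translate")   # speech in any language -> English text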
File "/code/code.py", line 82, in
result = transcribe(video_converted,language)
File "/code/codeTranscript.py", line 10, in transcribe
return transcribe_jax(audio,language=None)
File "/code/codeTranscript.py", line 25, in transcribe_jax
pipeline = FlaxWhisperPipline("models/whisper/large-v2.pt", batch_size=8)
File "/code/venv-3.10/lib/python3.10/site-packages/whisper_jax/pipeline.py", line 84, in init
self.processor = WhisperProcessor.from_pretrained(self.checkpoint)
File "/code/venv-3.10/lib/python3.10/site-packages/transformers/processing_utils.py", line 184, in from_pretrained
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
File "/code/venv-3.10/lib/python3.10/site-packages/transformers/processing_utils.py", line 228, in _get_arguments_from_pretrained
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
File "/code/venv-3.10/lib/python3.10/site-packages/transformers/feature_extraction_utils.py", line 329, in from_pretrained
feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
File "/code/venv-3.10/lib/python3.10/site-packages/transformers/feature_extraction_utils.py", line 457, in get_feature_extractor_dict
text = reader.read()
File "/opt/homebrew/Cellar/[email protected]/3.10.11/Frameworks/Python.framework/Versions/3.10/lib/python3.10/codecs.py", line 322, in decode
(result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Is there a way to force the model to transcribe only a specific language, e.g. Hindi?
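A sketch using the language argument that appears in other snippets in this thread (whether whisper-jax expects the ISO code "hi" or a language name here is an assumption worth double-checking):

import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
outputs = pipeline("audio.mp3", task="transcribe", language="hi")  # pin the decoded language to Hindi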
It would be nice to know how this compares to the ggml-based whisper.cpp implementation.
How can I pass the path of the input audio file to the pipeline? In the Kaggle notebook you are passing a dataset; should we just replace that with a path to our input, or is there another way?
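For reference, other snippets in this thread pass a plain file path directly to the pipeline call, e.g.:

import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
outputs = pipeline("audio.mp3", task="transcribe")  # the pipeline loads and resamples the file itself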
Hi there, I got sanchit's example from another issue working, but the speedup is only 8-10x real-time on an RTX 4090. The GPU is being used 100%, as far as I can tell from nvtop.
Maybe the following error is the reason?
2023-04-28 14:27:07.026383: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Compilation: 198.21863865852356
Cached: 173.1455545425415
I'm not able to get the transcription with word timestamps, only sentence timestamps.
Is this possible with whisper-jax?
Thanks
I don't get how I can link to the model in a local install. Should I replace /openai/largev2/ with the path of my model on disk?
And should I download the whole folder from Hugging Face, or should I just download the Flax file?
Hello, I want to confirm whether the OpenAI implementation in the benchmark uses the openai-whisper library or Hugging Face's WhisperForConditionalGeneration model. At the same time, I also want to confirm whether the Hugging Face implementation uses the FlaxWhisperForConditionalGeneration model.
If the OpenAI implementation uses the model in openai-whisper, does the performance test measure the execution time of DecodingTask.run()?
Is there a way to specify the device when loading the pipeline? It doesn't seem possible to pass a device id as you can with the 🤗 pipeline, like:
pipe = FlaxWhisperPipline("openai/whisper-large-v2", device=0, dtype=jnp.bfloat16, batch_size=16)
I'm running a benchmark on multiple models/pipelines, and whisper-jax takes up all the VRAM available on the 2 GPUs I have (A100 80GB), which causes an OOM error when I try to process an audio file.
I'd like the possibility to load whisper-jax on device 0 and the other models on any other devices I have.
Please recommend a way to do something like this.
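One workaround (standard CUDA/JAX behavior, not a whisper-jax feature): restrict which GPUs the process can see before JAX initializes, e.g.:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must run before jax allocates devices
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)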
I am getting the following error when using the "openai/whisper-medium" model with timestamp prediction:
There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?
This error comes from "transformers/models/whisper/tokenization_whisper.py", line 885. The generated tokens do not include any timestamps, except for the first one (0.0).
I have tested audios of different lengths (1 min to 1 h) and different parameters (half precision, stride), and the same error always occurs. On the other hand, with the base and large-v2 models this error does not occur.
Code:
model = "openai/whisper-medium"
whisper = FlaxWhisperPipline(model, dtype=jnp.float32)
res: dict = whisper(audio_file, stride_length_s=0.0, language="es", return_timestamps=True)
My computer:
Hey,
I appreciate your work; it is amazing. I wanted to use the model that I created, which has the .ckpt extension. I found issue #17, where you answered:
Download the entire repository to your local system, and then pass the path to this folder. E.g. if I cloned this checkpoint (https://huggingface.co/sanchit-gandhi/whisper-small-hi) into a folder called whisper-small-hi, I would pass ./whisper-small-hi
I do not have any folder for my ckpt file; the model is only that file, which is larger than 10 GB. When I try to pass that file at:
cc.initialize_cache("./jax_cache")
checkpoint = "my_checkpoint.ckpt"
BATCH_SIZE = 16
CHUNK_LENGTH_S = 30
NUM_PROC = 8
FILE_LIMIT_MB = 1000
YT_ATTEMPT_LIMIT = 3
It produces the error:
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Any help would be great. Thanks a lot in advance.
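For reference, a sketch of the folder-based loading the quoted answer describes (a raw OpenAI-format .ckpt/.pt file is not a Hugging Face repo folder, which is consistent with the UTF-8 decode error above; such a checkpoint would need converting to the HF format first):

from whisper_jax import FlaxWhisperPipline

# ./whisper-small-hi is a locally cloned HF model repo (the example from the quoted answer)
pipeline = FlaxWhisperPipline("./whisper-small-hi", batch_size=16)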
Awesome repo! I have one question, though: whenever I try running this code on my own TPU v4-8, I get the following error:
WARNING:absl:Tiling device assignment mesh by hosts, which may lead to reduced XLA collective performance. To avoid this, modify the model parallel submesh or run with more tasks per host.
Traceback (most recent call last):
File "fastapi_app.py", line 17, in <module>
pipeline.shard_params()
File "/root/ai/whisper-jax/whisper_jax/pipeline.py", line 127, in shard_params
self.params = p_shard_params(freeze(self.params))
File "/root/ai/whisper-jax/whisper_jax/partitioner.py", line 787, in __call__
return self._pjitted_fn(*args)
File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 238, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 193, in _python_pjit_helper
raise ValueError(msg) from None
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Received incompatible devices for pjitted computation. Got argument params['model']['decoder']['embed_positions']['embedding'] of FlaxPreTrainedModel.to_bf16 with shape float32[448,1280] and device ids [0] on platform CPU and pjit's devices with device ids [0, 2, 1, 3] on platform TPU
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "fastapi_app.py", line 17, in <module>
pipeline.shard_params()
File "/root/ai/whisper-jax/whisper_jax/pipeline.py", line 127, in shard_params
self.params = p_shard_params(freeze(self.params))
File "/root/ai/whisper-jax/whisper_jax/partitioner.py", line 787, in __call__
return self._pjitted_fn(*args)
ValueError: Received incompatible devices for pjitted computation. Got argument params['model']['decoder']['embed_positions']['embedding'] of FlaxPreTrainedModel.to_bf16 with shape float32[448,1280] and device ids [0] on platform CPU and pjit's devices with device ids [0, 2, 1, 3] on platform TPU
Any idea how I can fix it?
I used the default settings on the Kaggle notebook.
https://huggingface.co/datasets/sanchit-gandhi/whisper-jax-test-files
Hi! I am running on WSL2 with an RTX 3090.
I've noticed that faster-whisper runs about twice as fast on my 16k sampled 30s audio clip.
Is that to be expected or did I do something wrong with my JAX installation?
whisper-jax takes about 10s (once cached), while faster-whisper takes 5.1s
I set the faster-whisper beam_size to 1, is there an equivalent setting for whisper-jax?
Curious to know if it runs well with a fine-tuned Whisper model using PEFT.
Is it possible to load it in int8?
Hi! As the title says, my GPU is not being recognized: No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
But any other CUDA code (including OpenAI's whisper) does detect my GPU.
Thank you for the help!
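A quick check (plain JAX, as used elsewhere in this thread): with a CUDA-enabled jaxlib installed, jax.devices() should list GPU devices rather than CPU:

import jax

print(jax.devices())          # GPU devices with a CUDA jaxlib; [CpuDevice(id=0)] otherwise
print(jax.default_backend())  # "gpu" when the CUDA backend loaded, else "cpu"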
Hey,
I'm assuming this is a JAX issue, but I'm getting the following errors when trying to run the notebook on Google Colab:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
[<ipython-input-13-308fe9e13fe9>](https://pw2dauh3d9-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20230419-060138-RC00_525408879#) in <cell line: 1>()
----> 1 from whisper_jax import FlaxWhisperPipline
2 import jax.numpy as jnp
3
4 pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.bfloat16, batch_size=16)
4 frames
/usr/local/lib/python3.9/dist-packages/flax/core/frozen_dict.py in <module>
48
49
---> 50 @jax.tree_util.register_pytree_with_keys_class
51 class FrozenDict(Mapping[K, V]):
52 """An immutable variant of the Python dict."""
AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'
I've already tried the hints mentioned on JAX's GitHub page, but with no success:
# tpu
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
!pip install "jax<=0.3.25" "jaxlib<=0.3.25"
# gpu
import jax
jax.devices()
Hey all,
Very interesting work! I am trying to recreate some of the results you have in Table 1.
Do you happen to have the script and audio used on hand? I am having trouble matching it on my machine:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import time
import librosa

SAMPLING_RATE = 16000
audio, sr = librosa.load('test_audio.mp3', sr=SAMPLING_RATE)

# instantiate pipeline in half precision
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.float16, batch_size=32)

print("Warmup compiling forward pass")
text = pipeline(audio)

start_time = time.time()
for i in range(10):
    print(f"Go iter {i}")
    text = pipeline(audio)
end_time = time.time()

print(text)
print(f"Took {end_time - start_time} s")
# Took 330.93562269210815 s
test_audio.mp3 is a 13-minute TED talk clip. I get about 30 s per transcription iteration with this. It could be a bunch of things, but I just want to know whether this code would be expected to give the benchmark results under an optimal config.
Thank you! I want to set the cache dir.
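If this refers to the JAX compilation cache, the call used elsewhere in this thread points it at a directory of your choice:

from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("./jax_cache")  # compiled XLA programs are persisted here and reused across runs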
I have the following specs:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3050 T... On | 00000000:01:00.0 Off | N/A |
| N/A 45C P5 8W / 60W| 54MiB / 4096MiB | 41% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
I see the following warning before the program is killed:
W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
I do not see other errors:
python whisperJAX.py
2023-04-23 22:28:46.200680: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Killed
How can I resolve this issue? Please let me know if I need to share any more details
This issue occurs when I provide a YouTube link. I'm on Windows 11 (Python 3.10.6), using the command python app.py
Traceback (most recent call last):
File "/home/ethan/.local/lib/python3.10/site-packages/gradio/routes.py", line 401, in run_predict
output = await app.get_blocks().process_api(
File "/home/ethan/.local/lib/python3.10/site-packages/gradio/blocks.py", line 1302, in process_api
result = await self.call_function(
File "/home/ethan/.local/lib/python3.10/site-packages/gradio/blocks.py", line 1025, in call_function
prediction = await anyio.to_thread.run_sync(
File "/home/ethan/.local/lib/python3.10/site-packages/anyio/to_thread.py", line 31, in run_sync
return await get_asynclib().run_sync_in_worker_thread(
File "/home/ethan/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 937, in run_sync_in_worker_thread
return await future
File "/home/ethan/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 867, in run
result = context.run(func, *args)
File "/mnt/c/Users/rosha/whisper-jax/app/app.py", line 185, in transcribe_youtube
text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
File "/mnt/c/Users/rosha/whisper-jax/app/app.py", line 126, in tqdm_generate
model_outputs.append(forward(batch, task=task, return_timestamps=return_timestamps))
File "/mnt/c/Users/rosha/whisper-jax/app/app.py", line 69, in forward
outputs = chunked_query(
File "/mnt/c/Users/rosha/whisper-jax/app/app.py", line 62, in chunked_query
response = requests.post(API_URL_FROM_FEATURES, json=payload)
File "/home/ethan/.local/lib/python3.10/site-packages/requests/api.py", line 119, in post
return request('post', url, data=data, json=json, **kwargs)
File "/home/ethan/.local/lib/python3.10/site-packages/requests/api.py", line 61, in request
return session.request(method=method, url=url, **kwargs)
File "/home/ethan/.local/lib/python3.10/site-packages/requests/sessions.py", line 528, in request
prep = self.prepare_request(req)
File "/home/ethan/.local/lib/python3.10/site-packages/requests/sessions.py", line 456, in prepare_request
p.prepare(
File "/home/ethan/.local/lib/python3.10/site-packages/requests/models.py", line 316, in prepare
self.prepare_url(url, params)
File "/home/ethan/.local/lib/python3.10/site-packages/requests/models.py", line 390, in prepare_url
raise MissingSchema(error)
requests.exceptions.MissingSchema: Invalid URL 'None': No schema supplied. Perhaps you meant http://None?
Thank you, @sanchit-gandhi, for your fantastic work. I would appreciate your opinion on configuring an AWS machine for deploying Hugging Face's Whisper large model (JAX version) and data storage for both audio and streamed textual data.
My end goal is to deploy the stream-output model, but for now I am setting up the current model without stream functionality. What would be the optimal AWS configuration, considering the future scope of the project?
I have checked the main page and the Kaggle notebook, and there is no example of these.
With regular Whisper I was doing it like below.
How can I do this with whisper-jax? See the sketch after the snippet.
result = model.transcribe("../input/whisper2/lecture_" + str(lectureId) + ".mp3",
                          language="en", beam_size=10,
                          initial_prompt="Welcome to the Software Engineering Courses channel.",
                          best_of=10, verbose=True, temperature=0.0)

# save SRT
language = result["language"]
sub_name = "/kaggle/working/lecture_" + str(lectureId) + ".srt"
with open(sub_name, "w", encoding="utf-8") as srt:
    write_srt(result["segments"], file=srt)

# Save output
writing_lut = {
    '.txt': whisper.utils.write_txt,
    '.vtt': whisper.utils.write_vtt,
    '.srt': whisper.utils.write_srt,  # was write_txt in the original, presumably a typo
}
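A sketch of the whisper-jax side, assuming the pipeline returns Transformers-style "chunks" with (start, end) timestamps when return_timestamps=True; as far as this thread shows, there is no direct equivalent of beam_size/best_of/initial_prompt exposed in the pipeline call:

import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

def format_ts(seconds):
    # SRT timestamps look like HH:MM:SS,mmm
    ms = int(round(seconds * 1000))
    h, ms = divmod(ms, 3_600_000)
    m, ms = divmod(ms, 60_000)
    s, ms = divmod(ms, 1_000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
result = pipeline("lecture.mp3", task="transcribe", language="en", return_timestamps=True)

with open("lecture.srt", "w", encoding="utf-8") as srt:
    for i, chunk in enumerate(result["chunks"], start=1):
        start, end = chunk["timestamp"]
        if end is None:  # the final chunk can lack an end timestamp
            end = start
        srt.write(f"{i}\n{format_ts(start)} --> {format_ts(end)}\n{chunk['text'].strip()}\n\n")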
Looks like JAX does not support acceleration on the M1 / Apple Neural Engine.
Curious if anyone has done a benchmark comparison. Reference: