fast_gpt2's People

Contributors

narsil
fast_gpt2's Issues

Cargo fails to build on Apple M1

Here is what I got:

$ git clone https://github.com/Narsil/fast_gpt2
Cloning into 'fast_gpt2'...
remote: Enumerating objects: 405, done.
remote: Counting objects: 100% (229/229), done.
remote: Compressing objects: 100% (147/147), done.
remote: Total 405 (delta 153), reused 142 (delta 82), pack-reused 176
Receiving objects: 100% (405/405), 802.48 KiB | 2.57 MiB/s, done.
Resolving deltas: 100% (236/236), done.
$ cd fast_gpt2 
$ cargo run --example run --release
    Updating crates.io index
    Updating git repository `https://github.com/coreylowman/cudarc`
    Updating git repository `https://github.com/Narsil/dfdx`
    Updating git repository `https://github.com/huggingface/safetensors`
    Updating git repository `https://github.com/huggingface/tokenizers`
  Downloaded rayon-core v1.10.2
  Downloaded cc v1.0.79
  Downloaded reqwest v0.11.14
  Downloaded rustix v0.36.8
  Downloaded tokio-macros v1.8.2
  Downloaded tower-layer v0.3.2
  Downloaded tracing v0.1.37
  Downloaded base64 v0.21.0
  Downloaded memmap2 v0.5.10
  Downloaded tower-http v0.3.5
  Downloaded once_cell v1.17.1
  Downloaded proc-macro2 v1.0.51
  Downloaded macro_rules_attribute v0.1.3
  Downloaded security-framework-sys v2.8.0
  Downloaded tinyvec_macros v0.1.1
  Downloaded tempfile v3.4.0
  Downloaded tokio-util v0.7.7
  Downloaded serde_path_to_error v0.1.9
  Downloaded ryu v1.0.12
  Downloaded pin-project v1.0.12
  Downloaded tower-http v0.4.0
  Downloaded thread_local v1.1.7
  Downloaded serde_json v1.0.93
  Downloaded tracing-core v0.1.30
  Downloaded unicode-bidi v0.3.10
  Downloaded syn v1.0.109
  Downloaded serde_derive v1.0.152
  Downloaded unicode-ident v1.0.6
  Downloaded tracing-log v0.1.3
  Downloaded unicode-segmentation v1.10.1
  Downloaded slab v0.4.8
  Downloaded quote v1.0.23
  Downloaded rustversion v1.0.11
  Downloaded mio v0.8.6
  Downloaded tracing-subscriber v0.3.16
  Downloaded spin v0.9.5
  Downloaded tower v0.4.13
  Downloaded thread-tree v0.3.3
  Downloaded regex v1.7.1
  Downloaded crossbeam-utils v0.8.15
  Downloaded darling v0.14.3
  Downloaded crossbeam-channel v0.5.7
  Downloaded darling_macro v0.14.3
  Downloaded darling_core v0.14.3
  Downloaded futures v0.3.26
  Downloaded futures-channel v0.3.26
  Downloaded fastrand v1.9.0
  Downloaded derive_builder_macro v0.12.0
  Downloaded derive_builder_core v0.12.0
  Downloaded derive_builder v0.12.0
  Downloaded aho-corasick v0.7.20
  Downloaded futures-io v0.3.26
  Downloaded futures-task v0.3.26
  Downloaded glob v0.3.1
  Downloaded futures-macro v0.3.26
  Downloaded futures-core v0.3.26
  Downloaded crossbeam-epoch v0.9.14
  Downloaded crossbeam-deque v0.8.3
  Downloaded futures-executor v0.3.26
  Downloaded base64 v0.13.1
  Downloaded httparse v1.8.0
  Downloaded futures-util v0.3.26
  Downloaded http-range-header v0.3.0
  Downloaded io-lifetimes v1.0.5
  Downloaded h2 v0.3.16
  Downloaded either v1.8.1
  Downloaded http v0.2.9
  Downloaded thiserror-impl v1.0.38
  Downloaded thiserror v1.0.38
  Downloaded hyper v0.14.24
  Downloaded matchit v0.7.0
  Downloaded macro_rules_attribute-proc_macro v0.1.3
  Downloaded nu-ansi-term v0.46.0
  Downloaded onig v6.4.0
  Downloaded no-std-compat v0.4.1
  Downloaded native-tls v0.2.11
  Downloaded ipnet v2.7.1
  Downloaded memoffset v0.8.0
  Downloaded serde v1.0.152
  Downloaded paste v1.0.11
  Downloaded half v2.2.1
  Downloaded itoa v1.0.5
  Downloaded indexmap v1.9.2
  Downloaded futures-sink v0.3.26
  Downloaded num_cpus v1.15.0
  Downloaded async-trait v0.1.64
  Downloaded matrixmultiply v0.3.2
  Downloaded pin-project-internal v1.0.12
  Downloaded pkg-config v0.3.26
  Downloaded nom v7.1.3
  Downloaded tokio-native-tls v0.3.1
  Downloaded security-framework v2.8.2
  Downloaded ppv-lite86 v0.2.17
  Downloaded sync_wrapper v0.1.2
  Downloaded sharded-slab v0.1.4
  Downloaded rand_core v0.6.4
  Downloaded overload v0.1.1
  Downloaded rayon v1.6.1
  Downloaded find_cuda_helper v0.2.0
  Downloaded try-lock v0.2.4
  Downloaded axum-core v0.3.2
  Downloaded libc v0.2.139
  Downloaded tokio v1.25.0
  Downloaded onig_sys v69.8.1
  Downloaded encoding_rs v0.8.32
  Downloaded tracing-attributes v0.1.23
  Downloaded bytes v1.4.0
  Downloaded axum v0.6.9
  Downloaded 108 crates (8.8 MB) in 1.72s (largest was `encoding_rs` at 1.4 MB)
   Compiling proc-macro2 v1.0.51
   Compiling unicode-ident v1.0.6
   Compiling quote v1.0.23
   Compiling syn v1.0.109
   Compiling cfg-if v1.0.0
   Compiling autocfg v1.1.0
   Compiling libc v0.2.139
   Compiling memchr v2.5.0
   Compiling log v0.4.17
   Compiling once_cell v1.17.1
   Compiling pin-project-lite v0.2.9
   Compiling futures-core v0.3.26
   Compiling itoa v1.0.5
   Compiling bytes v1.4.0
   Compiling slab v0.4.8
   Compiling futures-channel v0.3.26
   Compiling futures-sink v0.3.26
   Compiling futures-task v0.3.26
   Compiling futures-util v0.3.26
   Compiling tracing-core v0.1.30
   Compiling futures-io v0.3.26
   Compiling pin-utils v0.1.0
   Compiling bitflags v1.3.2
   Compiling num_cpus v1.15.0
   Compiling tokio v1.25.0
   Compiling serde_derive v1.0.152
   Compiling crossbeam-utils v0.8.15
   Compiling socket2 v0.4.7
   Compiling mio v0.8.6
   Compiling serde v1.0.152
   Compiling fnv v1.0.7
   Compiling core-foundation-sys v0.8.3
   Compiling http v0.2.9
   Compiling memoffset v0.8.0
   Compiling io-lifetimes v1.0.5
   Compiling tower-service v0.3.2
   Compiling lazy_static v1.4.0
   Compiling rustversion v1.0.11
   Compiling crossbeam-epoch v0.9.14
   Compiling indexmap v1.9.2
   Compiling rustix v0.36.8
   Compiling strsim v0.10.0
   Compiling ident_case v1.0.1
   Compiling crossbeam-channel v0.5.7
   Compiling http-body v0.4.5
   Compiling errno v0.2.8
   Compiling ryu v1.0.12
   Compiling hashbrown v0.12.3
   Compiling scopeguard v1.1.0
   Compiling httparse v1.8.0
   Compiling core-foundation v0.9.3
   Compiling security-framework-sys v2.8.0
   Compiling tower-layer v0.3.2
   Compiling try-lock v0.2.4
   Compiling tinyvec_macros v0.1.1
   Compiling cc v1.0.79
   Compiling percent-encoding v2.2.0
   Compiling pkg-config v0.3.26
   Compiling serde_json v1.0.93
   Compiling either v1.8.1
   Compiling native-tls v0.2.11
   Compiling rayon-core v1.10.2
   Compiling fastrand v1.9.0
   Compiling form_urlencoded v1.1.0
   Compiling want v0.3.0
   Compiling tinyvec v1.6.0
   Compiling tempfile v3.4.0
   Compiling security-framework v2.8.2
   Compiling crossbeam-deque v0.8.3
   Compiling getrandom v0.2.8
   Compiling paste v1.0.11
   Compiling async-trait v0.1.64
   Compiling glob v0.3.1
   Compiling httpdate v1.0.2
   Compiling find_cuda_helper v0.2.0
   Compiling rand_core v0.6.4
   Compiling onig_sys v69.8.1
   Compiling axum-core v0.3.2
   Compiling minimal-lexical v0.2.1
   Compiling smallvec v1.10.0
   Compiling darling_core v0.14.3
   Compiling unicode-normalization v0.1.22
   Compiling esaxx-rs v0.1.8
   Compiling ppv-lite86 v0.2.17
   Compiling http-range-header v0.3.0
   Compiling unicode-bidi v0.3.10
   Compiling thiserror v1.0.38
   Compiling mime v0.3.16
   Compiling idna v0.3.0
   Compiling rand_chacha v0.3.1
   Compiling nom v7.1.3
   Compiling cudarc v0.7.5 (https://github.com/coreylowman/cudarc#6a01c299)
   Compiling rayon v1.6.1
   Compiling itertools v0.8.2
   Compiling axum v0.6.9
   Compiling aho-corasick v0.7.20
   Compiling base64 v0.13.1
   Compiling overload v0.1.1
   Compiling regex-syntax v0.6.28
   Compiling macro_rules_attribute-proc_macro v0.1.3
   Compiling unicode-segmentation v1.10.1
   Compiling nu-ansi-term v0.46.0
   Compiling futures-macro v0.3.26
   Compiling tokio-macros v1.8.2
   Compiling tracing-attributes v0.1.23
   Compiling darling_macro v0.14.3
   Compiling pin-project-internal v1.0.12
   Compiling thiserror-impl v1.0.38
   Compiling darling v0.14.3
   Compiling derive_builder_core v0.12.0
   Compiling pin-project v1.0.12
   Compiling regex v1.7.1
   Compiling tracing v0.1.37
   Compiling derive_builder_macro v0.12.0
   Compiling rayon-cond v0.1.0
   Compiling macro_rules_attribute v0.1.3
   Compiling url v2.3.1
error: failed to run custom build command for `cudarc v0.7.5 (https://github.com/coreylowman/cudarc#6a01c299)`

Caused by:
  process didn't exit successfully: `/Users/ondrej/repos/fast_gpt2/target/release/build/cudarc-d4bd025549ae484a/build-script-build` (exit status: 101)
  --- stderr
  thread 'main' panicked at 'Could not find a cuda installation', /Users/ondrej/.cargo/registry/src/github.com-1ecc6299db9ec823/find_cuda_helper-0.2.0/src/lib.rs:12:13
  note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
warning: build failed, waiting for other jobs to finish...

Do you require a GPU to run fast_gpt2?
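The panic comes from cudarc's build script, which unconditionally searches for a CUDA toolkit. One way to make a CPU-only build possible is to put the dependency behind an optional feature. This is only a sketch: the actual dependency specs and feature wiring in fast_gpt2's Cargo.toml may differ.

```toml
# Cargo.toml (sketch; exact dependency specs are assumptions)
[dependencies]
cudarc = { git = "https://github.com/coreylowman/cudarc", optional = true }

[features]
default = []            # CPU-only by default, so Apple M1 builds work
cuda = ["dep:cudarc"]   # opt in with: cargo run --release --features cuda
```

With a layout like this, `cargo run --example run --release` would skip cudarc entirely, and CUDA users would add `--features cuda`.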

Performance on Apple M1 Max

I am using the latest main (409c640) plus the following patch, which makes both PyTorch and fast_gpt2 run exactly the same model and prompt (generating 20 tokens), with no CUDA in either:

diff --git a/src/lib.rs b/src/lib.rs
index 367e2ca..9eb9347 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -87,7 +87,7 @@ pub async fn run() -> Result<(), Gpt2Error> {
     #[cfg(not(feature = "dfdx"))]
     let gpt2 = Gpt2::from_tensors(&tensors, num_heads);
 
-    let string = "My name is";
+    let string = "Alan Turing theorized that computers would one day become very powerful, but even he could not imagine";
 
     let encoded = tokenizer.encode(string, false).unwrap();
     println!("Loaded & encoded {:?}", start.elapsed());
@@ -101,7 +101,7 @@ pub async fn run() -> Result<(), Gpt2Error> {
     let mut current_ids = ids.clone();
     #[cfg(feature = "cuda")]
     profiler_start()?;
-    for _i in 0..10 {
+    for _i in 0..20 {
         // println!("-------------");
         let start = std::time::Instant::now();
         let new_id = gpt2.forward(&current_ids, &mut past_key_values);
diff --git a/test.py b/test.py
index 608b4cf..5405733 100644
--- a/test.py
+++ b/test.py
@@ -4,7 +4,7 @@ start = datetime.datetime.now()
 import torch
 
 print(f"Loaded torch {datetime.datetime.now() - start}")
-torch.zeros((2, 2)).cuda()
+torch.zeros((2, 2))
 print(f"Loaded torch (cuda) {datetime.datetime.now() - start}")
 
 
@@ -13,12 +13,12 @@ from transformers import pipeline
 print(f"Loaded transformers {datetime.datetime.now() - start}")
 
 
-pipe = pipeline(task="text-generation", model="gpt2-large", do_sample=False, device=0)
-pipe.model.config.max_length = None
+pipe = pipeline(task="text-generation", model="gpt2", do_sample=False)
+#pipe.model.config.max_length = None
 print(f"Loaded in {datetime.datetime.now() - start}")
 inf_start = datetime.datetime.now()
-new_tokens = 10
-out = pipe("My name is", max_length=3 + new_tokens)
+new_tokens = 20
+out = pipe("Alan Turing theorized that computers would one day become very powerful, but even he could not imagine", max_new_tokens=new_tokens)
 print(f"Tokens: {(datetime.datetime.now() - inf_start)/new_tokens}/tokens")
 print(f"Inference took: {(datetime.datetime.now() - inf_start)}")
 print(out)

Here is what I got for fast_gpt2:

$ cargo run --example run --release    
    Finished release [optimized] target(s) in 0.11s
     Running `target/release/examples/run`
Safetensors 1.86ms
Tokenizer 31.226958ms
Loaded & encoded 461.879041ms
Loop in 156.600333ms
Loop in 80.137333ms
Loop in 80.596916ms
Loop in 81.4075ms
Loop in 79.844708ms
Loop in 81.373583ms
Loop in 82.741458ms
Loop in 107.9175ms
Loop in 83.611083ms
Loop in 80.898125ms
Loop in 84.577875ms
Loop in 84.253166ms
Loop in 84.087083ms
Loop in 85.110708ms
Loop in 85.1405ms
Loop in 84.291708ms
Loop in 84.722125ms
Loop in 84.515916ms
Loop in 84.030916ms
Loop in 84.704333ms
Result Ok("Alan Turing theorized that computers would one day become very powerful, but even he could not imagine how they would be able to do so.\n\n\"I think that the most important thing is")
Total Inference 2.222943541s

And PyTorch (installed from conda-forge):

$ TRANSFORMERS_OFFLINE=1 python test.py
Loaded torch 0:00:00.359938
Loaded torch (cuda) 0:00:00.360043
Loaded transformers 0:00:02.340165
Loaded in 0:00:04.140099
/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/transformers/generation/utils.py:1186: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)
  warnings.warn(
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Tokens: 0:00:00.040217/tokens
Inference took: 0:00:00.804370
[{'generated_text': 'Alan Turing theorized that computers would one day become very powerful, but even he could not imagine how they would be able to do so.\n\n"I think that the most important thing is'}]
Ran in 0:00:04.944507

So fast_gpt2 runs in 2.2s total, while PyTorch's inference takes 0.8s.
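Per-token latency is a fairer comparison than total wall time, since fast_gpt2's "Total Inference" appears to be measured from program start and thus includes the ~0.46s load-and-encode step (an assumption based on the log ordering above). A quick back-of-the-envelope from the numbers above:

```python
# Per-token latency derived from the two logs above.
# fast_gpt2: "Total Inference 2.222943541s" minus "Loaded & encoded
# 461.879041ms", spread over the 20 generation loops.
# PyTorch: test.py printed 0.040217 s/token directly.

fast_gpt2_total = 2.222943541   # seconds, from the fast_gpt2 log
load_and_encode = 0.461879041   # seconds, from the fast_gpt2 log
tokens = 20
pytorch_per_token = 0.040217    # seconds/token, from the PyTorch log

fast_gpt2_per_token = (fast_gpt2_total - load_and_encode) / tokens

print(f"fast_gpt2: {fast_gpt2_per_token * 1000:.1f} ms/token")  # ~88 ms
print(f"pytorch:   {pytorch_per_token * 1000:.1f} ms/token")    # ~40 ms
print(f"ratio:     {fast_gpt2_per_token / pytorch_per_token:.2f}x")
```

So even excluding startup, fast_gpt2 is roughly 2.2x slower per token here, which matches the ~80–85ms "Loop in" lines against PyTorch's 40ms.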

To speed up fast_gpt2, we can use the fast matrix-matrix multiply (sgemm) from Apple's Accelerate framework, as shown in #10 (comment).
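A minimal sketch of what wiring in Accelerate could look like. The function name and module layout here are illustrative, not fast_gpt2's actual API; Accelerate exposes the standard CBLAS interface, so on macOS the crate's build.rs would also need to emit `cargo:rustc-link-lib=framework=Accelerate`. A portable naive loop is included as a fallback so the sketch compiles off macOS.

```rust
// Sketch: dispatch matmul to Accelerate's cblas_sgemm on macOS,
// with a naive portable fallback elsewhere.

#[cfg(target_os = "macos")]
extern "C" {
    // Provided by the Accelerate framework (standard CBLAS signature).
    fn cblas_sgemm(
        layout: i32, transa: i32, transb: i32,
        m: i32, n: i32, k: i32,
        alpha: f32, a: *const f32, lda: i32,
        b: *const f32, ldb: i32,
        beta: f32, c: *mut f32, ldc: i32,
    );
}

/// C (m x n) = A (m x k) * B (k x n), all row-major.
pub fn matmul(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(c.len(), m * n);

    #[cfg(target_os = "macos")]
    unsafe {
        // 101 = CblasRowMajor, 111 = CblasNoTrans
        cblas_sgemm(101, 111, 111,
                    m as i32, n as i32, k as i32,
                    1.0, a.as_ptr(), k as i32,
                    b.as_ptr(), n as i32,
                    0.0, c.as_mut_ptr(), n as i32);
    }

    #[cfg(not(target_os = "macos"))]
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = acc;
        }
    }
}

fn main() {
    // [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]]
    let a = [1.0, 2.0, 3.0, 4.0];
    let b = [5.0, 6.0, 7.0, 8.0];
    let mut c = [0.0f32; 4];
    matmul(&a, &b, &mut c, 2, 2, 2);
    println!("{:?}", c);
}
```

Since GPT-2 inference is dominated by these matmuls, routing them through a tuned BLAS should close most of the gap with PyTorch, which already uses Accelerate on Apple Silicon.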
