$ python gpt2.py "Alan Turing theorized that computers would one day become" -n 8
generating: 100%|█████████████████████████████████| 8/8 [00:03<00:00, 2.44it/s]
the most powerful machines on the planet.
$ python gpt2.py "Alan Turing theorized that computers would one day become" -n 8
generating: 0%| | 0/8 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 121, in <module>
fire.Fire(main)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 110, in main
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 92, in generate
inputs = np.append(inputs, [next_id]) # append prediction to input
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/api.py", line 694, in cache_miss
execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 240, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 351, in _xla_callable_uncached
computation = sharded_lowering(fun, device, backend, name, donated_invars,
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 342, in sharded_lowering
return pxla.lower_sharding_computation(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2797, in lower_sharding_computation
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2073, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2006, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2802, in append
return concatenate([ravel(arr), ravel(values)], 0)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/api.py", line 698, in cache_miss
top_trace.process_call(primitive, fun_, tracers, params))
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1747, in process_call
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 2035, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 812, in ravel
_stackable(a) or _check_arraylike("ravel", a)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: ravel requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 121, in <module>
fire.Fire(main)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 110, in main
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
File "/Users/ondrej/repos/picoGPT/gpt2.py", line 92, in generate
inputs = np.append(inputs, [next_id]) # append prediction to input
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2802, in append
return concatenate([ravel(arr), ravel(values)], 0)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 812, in ravel
_stackable(a) or _check_arraylike("ravel", a)
File "/Users/ondrej/mambaforge/envs/pico/lib/python3.9/site-packages/jax/_src/numpy/util.py", line 345, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: ravel requires ndarray or scalar arguments, got <class 'list'> at position 0.
$ conda env export
name: pico
channels:
- conda-forge
dependencies:
- appdirs=1.4.4=pyh9f0ad1d_0
- brotlipy=0.7.0=py39h02fc5c5_1005
- bzip2=1.0.8=h3422bc3_4
- c-ares=1.18.1=h3422bc3_0
- ca-certificates=2022.12.7=h4653dfc_0
- cffi=1.15.1=py39h7e6b969_3
- cryptography=39.0.1=py39he2a39a8_0
- idna=3.4=pyhd8ed1ab_0
- jax=0.4.3=pyhd8ed1ab_0
- jaxlib=0.4.3=cpu_py39h99d3290_1
- libabseil=20220623.0=cxx17_h28b99d4_6
- libblas=3.9.0=16_osxarm64_openblas
- libcblas=3.9.0=16_osxarm64_openblas
- libcxx=14.0.6=h2692d47_0
- libffi=3.4.2=h3422bc3_5
- libgfortran=5.0.0=11_3_0_hd922786_27
- libgfortran5=11.3.0=hdaf2cc0_27
- libgrpc=1.51.1=hb15be72_1
- liblapack=3.9.0=16_osxarm64_openblas
- libopenblas=0.3.21=openmp_hc731615_3
- libprotobuf=3.21.12=hb5ab8b9_0
- libsqlite=3.40.0=h76d750c_0
- libzlib=1.2.13=h03a7124_4
- llvm-openmp=15.0.7=h7cfbb63_0
- ncurses=6.3=h07bb92c_1
- openssl=3.0.8=h03a7124_0
- opt_einsum=3.3.0=pyhd8ed1ab_1
- packaging=23.0=pyhd8ed1ab_0
- pip=23.0=pyhd8ed1ab_0
- pooch=1.6.0=pyhd8ed1ab_0
- pycparser=2.21=pyhd8ed1ab_0
- pyopenssl=23.0.0=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.9.16=hea58f1e_0_cpython
- python_abi=3.9=3_cp39
- re2=2023.02.01=hb7217d7_0
- readline=8.1.2=h46ed386_0
- scipy=1.10.0=py39h18313fe_2
- setuptools=67.1.0=pyhd8ed1ab_0
- tk=8.6.12=he1e0b03_0
- tzdata=2022g=h191b570_0
- urllib3=1.26.14=pyhd8ed1ab_0
- wheel=0.38.4=pyhd8ed1ab_0
- xz=5.2.6=h57fd34a_0
- zlib=1.2.13=h03a7124_4
- pip:
    - absl-py==1.4.0
    - astunparse==1.6.3
    - cachetools==5.3.0
    - certifi==2022.12.7
    - charset-normalizer==2.0.12
    - fire==0.5.0
    - flatbuffers==23.1.21
    - gast==0.4.0
    - google-auth==2.16.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.51.1
    - h5py==3.8.0
    - importlib-metadata==6.0.0
    - keras==2.11.0
    - libclang==15.0.6.1
    - markdown==3.4.1
    - markupsafe==2.1.2
    - numpy==1.24.1
    - oauthlib==3.2.2
    - protobuf==3.19.6
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - regex==2017.4.5
    - requests==2.27.1
    - requests-oauthlib==1.3.1
    - rsa==4.9
    - six==1.16.0
    - tensorboard==2.11.2
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - tensorflow-estimator==2.11.0
    - tensorflow-macos==2.11.0
    - termcolor==2.2.0
    - tqdm==4.64.0
    - typing-extensions==4.4.0
    - werkzeug==2.2.2
    - wrapt==1.14.1
    - zipp==3.13.0
prefix: /Users/ondrej/mambaforge/envs/pico