axlearn's Introduction

The AXLearn Library for Deep Learning

This library is under active development and the API is subject to change.

Table of Contents

  • Introduction: What is AXLearn?
  • Getting Started: Getting up and running with AXLearn.
  • Concepts: Core concepts and design principles.
  • CLI User Guide: How to use the CLI.
  • Infrastructure: Core infrastructure components.

Introduction

AXLearn is a library built on top of JAX and XLA to support the development of large-scale deep learning models.

AXLearn takes an object-oriented approach to the software engineering challenges that arise from building, iterating, and maintaining models. The configuration system of the library lets users compose models from reusable building blocks and integrate with other libraries such as Flax and Hugging Face transformers.
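The configure-then-instantiate idea can be illustrated with a small stand-alone sketch: a model is described by a tree of configs, child configs are reusable templates, and objects are only built at instantiation time. The classes below (LinearConfig, MLPConfig, Linear, MLP) are hypothetical and written with plain Python dataclasses; they are not AXLearn's actual config API.

```python
from dataclasses import dataclass, field, replace
from typing import Any


@dataclass(frozen=True)
class LinearConfig:
    """Hypothetical config for a reusable building block."""
    input_dim: int = 0
    output_dim: int = 0
    use_bias: bool = True

    def set(self, **kwargs: Any) -> "LinearConfig":
        # Return an updated copy; the original (default) config is never mutated.
        return replace(self, **kwargs)

    def instantiate(self) -> "Linear":
        return Linear(self)


class Linear:
    def __init__(self, cfg: LinearConfig):
        self.cfg = cfg


@dataclass(frozen=True)
class MLPConfig:
    input_dim: int = 0
    hidden_dim: int = 0
    output_dim: int = 0
    # A child config: users can swap or tweak the building block before instantiation.
    layer: LinearConfig = field(default_factory=LinearConfig)

    def set(self, **kwargs: Any) -> "MLPConfig":
        return replace(self, **kwargs)

    def instantiate(self) -> "MLP":
        return MLP(self)


class MLP:
    def __init__(self, cfg: MLPConfig):
        self.cfg = cfg
        # The same child template is specialized twice with different dimensions.
        self.layer1 = cfg.layer.set(input_dim=cfg.input_dim, output_dim=cfg.hidden_dim).instantiate()
        self.layer2 = cfg.layer.set(input_dim=cfg.hidden_dim, output_dim=cfg.output_dim).instantiate()
```

For example, MLPConfig().set(input_dim=4, hidden_dim=8, output_dim=2, layer=LinearConfig(use_bias=False)).instantiate() builds two bias-free layers of shapes 4→8 and 8→2 without touching the default configs.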

AXLearn is built to scale. It supports the training of models with up to hundreds of billions of parameters across thousands of accelerators at high utilization. It is also designed to run on public clouds and provides tools to deploy and manage jobs and data. Built on top of GSPMD, AXLearn adopts a global computation paradigm to allow users to describe computation on a virtual global computer rather than on a per-accelerator basis.
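In JAX terms, the global computation paradigm looks roughly like the sketch below: the function is written against full, global array shapes, and a sharding annotation tells XLA's SPMD partitioner how the data is split across devices. This is plain JAX (jax.sharding.Mesh, NamedSharding, PartitionSpec), not AXLearn's own API; the mesh axis name "data" and the array shapes are arbitrary choices for illustration. It runs on a single CPU device as well as on many accelerators.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over every available device (a single CPU device also works).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = len(jax.devices())

# Shard the leading (batch) dimension of a global array across the "data" axis.
batch_sharding = NamedSharding(mesh, PartitionSpec("data", None))
x = jax.device_put(jnp.ones((4 * n, 3), dtype=jnp.float32), batch_sharding)
w = jnp.full((3, 2), 2.0, dtype=jnp.float32)  # small weights, left unsharded


@jax.jit
def forward(x, w):
    # Written against the global shapes; XLA decides what each device computes
    # and inserts any cross-device communication (here, for the mean).
    return jnp.mean(x @ w)


loss = forward(x, w)
```

Every row of x @ w is [6.0, 6.0], so the global mean is 6.0 regardless of how many devices the batch is split over.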

AXLearn supports a wide range of applications, including natural language processing, computer vision, and speech recognition, and contains baseline configurations for training state-of-the-art models.

Please see Concepts for more details on the core components and design of AXLearn, or Getting Started if you want to get your hands dirty.

axlearn's People

Contributors

alex8937, altimofeev, amcw7777, apghml, ethanlm, fnan, gyin94, haijingfu, jianyuwangv, jiarui-lu2, jinhaolei, kelvin-zou, madrob, markblee, qdavid1, ruomingp, swiseman, taolei87, tarangkhanna, tgunter, tombstone, tuzhucheng, weiliu89, wwu137, xianzhidu, ya5ut, yqwangustc, zbwglory, zhiyun, zhzhyi


axlearn's Issues

BUG: axlearn.common.utils.as_tensor calls .numpy(), which doesn't work with bfloat16

I noticed this problem while developing Medusa+, which depends on axlearn. Here is the raw trace of the error when converting an Ajax tensor in bfloat16 to torch.tensor: https://rio.apple.com/projects/ai-medusa-plus/pipeline-specs/ai-medusa-plus-unit_tests/pipelines/f4345942-bddc-47ba-ba51-f3c1019290d3/log#L718-L738

The problem is that NumPy does not support bfloat16, so if the source tensor is bfloat16, the call to .numpy() raises an error.

The following script reveals this problem:

import torch
import numpy
import jax.numpy as jnp

x = torch.rand(
    (1,),
    dtype=torch.float32,
)
print(x.numpy())

x = torch.rand(
    (1,),
    dtype=torch.float16,
)
print(x.numpy())

x = torch.rand(
    (1,),
    dtype=torch.bfloat16,
)
print(x.numpy())

The first two calls to .numpy() succeed; the last one fails:

22:36 $ python3 medusa_plus/numpy_bf16.py
[0.686854]
[0.6177]
Traceback (most recent call last):
  File "/mnt/medusa-plus/medusa_plus/numpy_bf16.py", line 21, in <module>
    print(x.numpy())
TypeError: Got unsupported ScalarType BFloat16
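bfloat16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa bits of an IEEE float32, so widening bfloat16 to float32 is exact. In PyTorch the usual workaround is therefore to upcast before converting, e.g. x.float().numpy(). The NumPy-only sketch below (no torch needed) makes that bit-level relationship concrete by treating raw bfloat16 values as uint16 bit patterns; the helper names are hypothetical and not part of axlearn.

```python
import numpy as np


def bf16_bits_to_float32(bits) -> np.ndarray:
    """Widen raw bfloat16 bit patterns (stored as uint16) to float32, exactly."""
    bits = np.asarray(bits, dtype=np.uint16)
    # bfloat16 is the high half of a float32 word: shift it into place, then reinterpret.
    return (bits.astype(np.uint32) << 16).view(np.float32)


def float32_to_bf16_bits(x) -> np.ndarray:
    """Truncate float32 values to bfloat16 bit patterns (rounds toward zero)."""
    x = np.asarray(x, dtype=np.float32)
    # Keep only the upper 16 bits; the low 16 mantissa bits are dropped.
    return (x.view(np.uint32) >> 16).astype(np.uint16)
```

The round trip is lossless for values that already fit in bfloat16 (such as 1.0, -2.0, 0.5), while a value like 1 + 2**-10 loses its low mantissa bits and comes back as 1.0.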

Configure gcloud when configuring axlearn

I ran axlearn gcp config activate to activate a project in us-west1. However, this doesn't change my gcloud config, as is evident from running gcloud config list.

If I don't change the gcloud config manually, I run into the following error when using this command: axlearn gcp dataflow start ...

E0319 20:55:05.981639 140440028389440 config.py:118] Unknown settings for project=<project> and zone=us-west1-a; You may want to configure this project first; Please refer to the docs for details.
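As a manual workaround, gcloud can be pointed at the same project and zone that axlearn activated. The sketch below assumes the gcloud CLI is on PATH and uses the zone named in the error (us-west1-a); the helper names and the "my-project" value are hypothetical, while gcloud config set project and gcloud config set compute/zone are standard gcloud commands. It only prints the commands unless dry_run=False.

```python
import subprocess
from typing import List


def gcloud_sync_commands(project: str, zone: str) -> List[List[str]]:
    """Hypothetical helper: gcloud commands that mirror axlearn's active project/zone."""
    return [
        ["gcloud", "config", "set", "project", project],
        ["gcloud", "config", "set", "compute/zone", zone],
    ]


def sync_gcloud_config(project: str, zone: str, dry_run: bool = True) -> None:
    for cmd in gcloud_sync_commands(project, zone):
        if dry_run:
            # Show what would run without touching the local gcloud config.
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)
```

Running gcloud config list afterwards should then show the same project and zone that axlearn gcp config activate selected.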

axlearn on GPU started failing during init after upgrade

I see the following error when launching like this:

timeout -k 60s 900s python3 -m axlearn.common.launch_trainer_main --module=gke_fuji --config=fuji-7B-b512-fsdp8 --trainer_dir=/tmp/test_trainer --data_dir=gs://axlearn-public/tensorflow_datasets --jax_backend=gpu --num_processes=8 --distributed_coordinator=stoelinga-may13-1-j-0-0.stoelinga-may13-1 --process_id=0 --trace_at_steps=25

Error message:

2024-05-13 16:17:05.732984: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/tmp/axlearn/axlearn/common/launch_trainer_main.py", line 16, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/tmp/axlearn/axlearn/common/launch_trainer_main.py", line 10, in main
    launch.setup()
  File "/tmp/axlearn/axlearn/common/launch.py", line 92, in setup
    setup_spmd(
  File "/tmp/axlearn/axlearn/common/utils_spmd.py", line 118, in setup
    jax.distributed.initialize(**init_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/distributed.py", line 196, in initialize
    global_state.initialize(coordinator_address, num_processes, process_id,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/distributed.py", line 72, in initialize
    default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1]
IndexError: list index out of range
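The last frame shows JAX deriving its bind address with coordinator_address.rsplit(':', 1)[1], i.e. it expects the coordinator address as "host:port". The --distributed_coordinator value in the launch command (stoelinga-may13-1-j-0-0.stoelinga-may13-1) has no port, so rsplit returns a single element and indexing [1] raises IndexError. A likely workaround is to pass --distributed_coordinator=<host>:<port>. The sketch below reproduces the failure and shows a defensive parse; both helpers are hypothetical, and the default port 1234 is an arbitrary placeholder, not a value axlearn or JAX prescribes.

```python
from typing import Tuple


def parse_coordinator_address(address: str) -> Tuple[str, int]:
    """Split 'host:port', failing with a clear message instead of an IndexError."""
    host, sep, port = address.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"coordinator address {address!r} must look like 'host:port'")
    return host, int(port)


def with_default_port(address: str, port: int = 1234) -> str:
    """Hypothetical helper: append a port when none is given (1234 is a placeholder)."""
    return address if ":" in address else f"{address}:{port}"
```

Addresses that already carry a port are passed through unchanged, so the helper is safe to apply unconditionally to plain hostnames and IPv4 addresses.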

C4 Dataset Bucket has Permissions Issues

Hi, as of today I get this error when using the default C4 dataset bucket; it was working before today.
DATA_DIR="gs://axlearn-public/tensorflow_datasets"

  File "/shared_new/ptoulme/axlearn/axlearn/axlearn/common/input_tf_data.py", line 277, in fn
    builder = tfds.builder(dataset_name, data_dir=data_dir)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/tensorflow_datasets/core/logging/__init__.py", line 168, in __call__
    return function(*args, **kwargs)
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/tensorflow_datasets/core/load.py", line 210, in builder
    return read_only_builder.builder_from_files(str(name), **builder_kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/tensorflow_datasets/core/read_only_builder.py", line 265, in builder_from_files
    builder_dir = _find_builder_dir(name, **builder_kwargs)
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/tensorflow_datasets/core/read_only_builder.py", line 334, in _find_builder_dir
    builder_dir = _find_builder_dir_single_dir(
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/tensorflow_datasets/core/read_only_builder.py", line 406, in _find_builder_dir_single_dir
    if _contains_dataset(dataset_dir):
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/tensorflow_datasets/core/read_only_builder.py", line 383, in _contains_dataset
    return feature_lib.make_config_path(dataset_dir).exists()
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/etils/epath/gpath.py", line 148, in exists
    return self._backend.exists(self._path_str)
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/etils/epath/backend.py", line 277, in exists
    return self.gfile.exists(path)
  File "/shared_new/ptoulme/axlearn/venv/lib/python3.10/site-packages/tensorflow/python/lib/io/file_io.py", line 290, in file_exists_v2
    _pywrap_file_io.FileExists(compat.path_to_bytes(path))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Error executing an HTTP request: HTTP response code 400 with body '{
  "error": {
    "code": 400,
    "message": "Bucket is a requester pays bucket but no user project provided.",
    "errors": [
      {
        "message": "Bucket is a requester pays bucket but no user project provided.",
        "domain": "global",
        "reason": "required"
      }
    ]
  }
}
'
         when reading metadata of gs://axlearn-public/tensorflow_datasets/c4/en/3.0.1/features.json

The above exception was the direct cause of the following exception

Making core dependencies optional causes a "No module named 'absl'" error

Launching a TPU job on GKE threw this error in the container:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 5, in <module>
    from absl import app, flags
ModuleNotFoundError: No module named 'absl'

I believe this is due to this commit, which made the core dependencies optional: eb7d655

This is how I launch my GKE TPU job:

axlearn gcp gke start --instance_type=tpu-v5litepod-256 --num_replicas=1 \
        --cluster=snip-us-west4 \
        --queue=multislice-queue \
        --bundler_spec=allow_dirty=True \
        --bundler_type=artifactregistry --bundler_spec=image=tpu \
        --bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu \
        -- python3 -m axlearn.common.launch_trainer_main \
        --module=text.gpt.c4_trainer --config=fuji-7B-v2 \
          --trainer_dir=gs://snip-multipod-axlearn \
          --data_dir=gs://axlearn-public/tensorflow_datasets  \
          --jax_backend=tpu

One fix that @jesus-orozco came up with is to use:

tpu = [
    "axlearn[core,gcp]",
    "jax[tpu]==0.4.28",  # must be >=0.4.19 for compat with v5p.
]

Missing googleapiclient as dependency

I get ModuleNotFoundError: No module named 'googleapiclient' when calling:

from axlearn.cloud.gcp.vm import ..

It seems that google-api-python-client needs to be added as a dependency of axlearn?
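Both of the reports above are ModuleNotFoundErrors that only surface once code runs (the absl one inside the GKE container). A quick preflight check run in the built image can catch a missing dependency before a job is launched. In this sketch the helper name is hypothetical and the module list is an assumption; absl is provided by the absl-py package and googleapiclient by google-api-python-client.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    # find_spec locates a module without importing it and returns None if it is absent.
    return [name for name in names if importlib.util.find_spec(name) is None]
```

For example, calling missing_modules(["absl", "googleapiclient"]) in the container and failing the build if the result is non-empty would surface either problem before any TPU time is spent.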

nodeSelector set by default requires the TPU provisioner

These are the nodeSelectors that got added:

Node-Selectors:              cloud.google.com/gke-accelerator-count=4
                             cloud.google.com/gke-spot=true
                             cloud.google.com/gke-tpu-accelerator=tpu-v5-lite-podslice
                             cloud.google.com/gke-tpu-topology=16x16
                             provisioner-nodepool-id=stoelinga-8733bd

This was my launch command:

export BASTION_TIER=1
axlearn gcp gke start --instance_type=tpu-v5litepod-256 --num_replicas=1 \
        --cluster=v5e-256-bodaborg-us-west4 --bundler_spec=allow_dirty=True \
        --bundler_type=artifactregistry --bundler_spec=image=tpu \
        --bundler_spec=dockerfile=Dockerfile --bundler_spec=target=tpu \
        -- python3 -c "'import jax; print(jax.devices())'"

Expectation:
The job should not have the selector provisioner-nodepool-id=stoelinga-8733bd, since it assumes the TPU provisioner is always used, which may not be the case for external users.
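The expected behavior can be sketched as making the provisioner label opt-in: the TPU labels are always set, and provisioner-nodepool-id is added only when a node pool ID is actually supplied. The label keys are copied from the Node-Selectors output above; the function and its provisioner_nodepool_id parameter are hypothetical and not an existing axlearn option.

```python
from typing import Dict, Optional


def tpu_node_selector(
    accelerator: str,
    topology: str,
    accelerator_count: int,
    provisioner_nodepool_id: Optional[str] = None,
) -> Dict[str, str]:
    """Build TPU nodeSelector labels; the provisioner label is opt-in."""
    selector = {
        "cloud.google.com/gke-tpu-accelerator": accelerator,
        "cloud.google.com/gke-tpu-topology": topology,
        "cloud.google.com/gke-accelerator-count": str(accelerator_count),
    }
    # Only pin to a provisioner-managed node pool when a provisioner is actually in use.
    if provisioner_nodepool_id is not None:
        selector["provisioner-nodepool-id"] = provisioner_nodepool_id
    return selector
```

With the default arguments, a cluster whose node pools were created without the TPU provisioner would not be filtered out by the extra label.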

Unable to use TPU on GKE using on-demand quota

Currently, axlearn adds either a nodeSelector for spot=true or a nodeSelector for a reservation:

        if tier == "0" and cfg.reservation is not None:
            logging.info("Found tier=%s in env. Using reservation=%s", tier, cfg.reservation)
            selector.update({"cloud.google.com/reservation-name": cfg.reservation})
        else:
            logging.info("Found tier=%s in env. Using spot quota", tier)
            selector.update({"cloud.google.com/gke-spot": "true"})
            tolerations.append(
                {
                    "key": "cloud.google.com/gke-spot",
                    "operator": "Equal",
                    "value": "true",
                    "effect": "NoSchedule",
                }
            )

It should be possible to launch a job using on-demand TPU quota; however, today that is not possible unless you remove this line:

selector.update({"cloud.google.com/gke-spot": "true"})
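One way to express the requested behavior is a third branch that adds neither selector. The sketch below is a stand-alone rewrite of the snippet above using the standard logging module, assuming a hypothetical enable_spot switch (not an existing axlearn option). With enable_spot=False and no tier-0 reservation, the pod gets no spot selector or toleration, so it is no longer restricted to spot nodes.

```python
import logging
from typing import Dict, List, Optional, Tuple

SPOT_KEY = "cloud.google.com/gke-spot"


def quota_selector(
    tier: str, reservation: Optional[str], enable_spot: bool
) -> Tuple[Dict[str, str], List[Dict[str, str]]]:
    """Return (nodeSelector, tolerations) for reservation, spot, or on-demand quota."""
    selector: Dict[str, str] = {}
    tolerations: List[Dict[str, str]] = []
    if tier == "0" and reservation is not None:
        logging.info("Found tier=%s in env. Using reservation=%s", tier, reservation)
        selector["cloud.google.com/reservation-name"] = reservation
    elif enable_spot:
        logging.info("Found tier=%s in env. Using spot quota", tier)
        selector[SPOT_KEY] = "true"
        tolerations.append(
            {"key": SPOT_KEY, "operator": "Equal", "value": "true", "effect": "NoSchedule"}
        )
    else:
        # Neither a reservation nor spot: leave both out so on-demand nodes are eligible.
        logging.info("Found tier=%s in env. Using on-demand quota", tier)
    return selector, tolerations
```

Keeping spot as the default (enable_spot=True) would preserve today's behavior for existing users while giving on-demand users an explicit way out.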
