callummcdougall / sae_vis

Create feature-centric and prompt-centric visualizations for sparse autoencoders (like those from Anthropic's published research).

License: MIT License

Jupyter Notebook 1.13% Python 0.98% CSS 0.03% JavaScript 0.14% HTML 97.71% Makefile 0.01%

sae_vis's People

Contributors

afspies, arthurconmy, callummcdougall, chanind, hijohnnylin, jbloomaus, jordansauce, lewington-pitsos, lucyfarnik, shehper, wllgrnt


sae_vis's Issues

Load a local model

How can I load a local model instead of downloading one through Hugging Face? Is there any sample code?
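For context, something like the following is what I'm after. A minimal sketch, assuming a HF-format checkpoint saved locally (the path and the architecture name "gpt2" are placeholders); sae_vis works with TransformerLens models, and HookedTransformer.from_pretrained accepts hf_model and tokenizer overrides:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

local_path = "/path/to/local/checkpoint"  # placeholder: local directory with config + weights

# Load the weights and tokenizer from disk instead of the Hugging Face hub
hf_model = AutoModelForCausalLM.from_pretrained(local_path)
tokenizer = AutoTokenizer.from_pretrained(local_path)

# The architecture name tells TransformerLens which config to use;
# the weights themselves come from the locally loaded hf_model
model = HookedTransformer.from_pretrained("gpt2", hf_model=hf_model, tokenizer=tokenizer)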

Prompt-centric vis gone wild.

Hi, I have a question regarding the prompt-centric visualiser; it seems that there are some issues with the code... or maybe I am doing it wrong.

Here it is:

prompt = <Some long text> 
filename = "_prompt_vis_demo.html"
sae_vis_data.save_prompt_centric_vis(
    prompt = prompt,
    filename = filename
)

I use the code above for visualisation. However, no matter which model I use or which prompt I evaluate, it always returns the following error:

AssertionError: Key not found in scores_dict.keys()=dict_keys([]).
This means that there are no features with a nontrivial score for this choice of key & metric.

This happens even when I use a pretrained model from the SAELens demo.

I thought it might happen because my SAE is too sparse, though it did not seem like that the last time I checked.
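For reference, a minimal sketch of the kind of check I mean, assuming the model, sae, and prompt objects from the demo notebook (sae.encode and sae.cfg.hook_name follow SAELens conventions):

tokens = model.to_tokens(prompt)              # TransformerLens tokenization
_, cache = model.run_with_cache(tokens)
acts = cache[sae.cfg.hook_name]               # the activations the SAE was trained on
feat_acts = sae.encode(acts)                  # shape [batch, seq, n_features]
print((feat_acts > 0).float().mean().item())  # ~0 would mean nothing fires on the prompt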
I would really appreciate it if you could at least point me somewhere with this issue.
Thanks in advance.

Support Attention Output (hook_z) SAEs + DFA by source position

Currently the library doesn't support Attention Output (hook_z) SAEs. I personally use these a ton (and know of a few other groups working with them), and it would be great to just use sae_vis out of the box! I think this would be an easy change.

Relatedly, it would be great to support DFA (direct feature attribution) by source position for the hook_z dashboards, as this makes interpreting attention output features much easier. Example: induction features are tricky to spot with max activating examples, but obvious with DFA.

Move to pyproject.toml for packaging / dependencies

pyproject.toml is the recommended way to handle packaging and dependencies in modern Python. The easiest way to manage this is to use a Python dependency manager like Poetry or PDM; both can handle dependency management as well as building, packaging, and publishing Python projects.

Set up auto-deploy action

There are two common ways this can work. One is semantic-release, where keywords in the commit message determine how the version number should be bumped and a changelog is generated automatically; this is usually my preferred style. The other is to manually cut a release tag on GitHub and trigger a push to PyPI when that happens. This is more manual, but gives the repo owner more explicit control over the process.

Setup tooling for running automated tests

Pytest is the default choice for automated testing in Python and would be good to set up for this project. This would also entail setting up a GitHub action to run tests automatically on every commit / PR.
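A minimal sketch of what a first test file could look like (the behaviour tested here is generic torch top-k, purely to illustrate the pytest conventions, not sae_vis's actual API):

import torch

def test_topk_values_are_sorted_descending():
    x = torch.tensor([0.2, 2.88, 0.9, 0.0])
    values, indices = x.topk(k=2)
    # topk returns values in descending order, with their original positions
    assert torch.equal(values, torch.tensor([2.88, 0.9]))
    assert torch.equal(indices, torch.tensor([1, 2]))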

Need support for gated SAE

AssertionError: If encoder isn't an AutoEncoder, it should have weights 'W_enc', 'W_dec', 'b_enc', 'b_dec'

Gated SAEs do not have b_enc, and it seems the AutoEncoder class is not suitable for gated SAEs.

Circular dependency between SAELens and sae_vis

sae_vis lists SAELens in its dependencies, but SAELens also depends on sae_vis. This doesn't seem like a stable situation and is likely to cause issues.

It looks like the SAELens dependency in sae_vis is only there for demo.ipynb. Since this isn't a core dependency, I would recommend removing it from the main dependency list.

One option is to create an extras group called demo or something similar, so that it's possible to run pip install sae-vis[demo] for the demo requirements, but these won't be installed by default. More info on extras: https://python-poetry.org/docs/pyproject/#extras

Another possible solution is to just add !pip install sae-lens to the top of the demo.ipynb file, so the dependency is self-contained there.

Set up tooling for linting and auto-formatting

For new projects, the common choice is Ruff, as it can handle both linting and formatting and is faster than tools written in Python. Alternatively, Black and Flake8 can be used individually. This would also entail setting up a GitHub action to check linting and formatting on each commit / PR.

Activation Sequence shows up in the wrong "group" (SequenceGroupData)

Issue:
sae_vis returns activation texts in "groups" according to which quantile they're in, or whether they're in the top activating group. The problem is that it sometimes returns an activation text in the wrong group. I'm able to reproduce an issue where it puts an activation text with a max act of 2.88 into a group that is supposed to cover the range 0.000 to 0.578 (this is testing res-jb).

Code Details:
SaeVisData returns FeatureDatas. Each FeatureData has sequence_data (a SequenceMultiGroupData), which itself has seq_group_data (an array of SequenceGroupData). Each SequenceGroupData has a title containing the activation group information (e.g. "INTERVAL min_interval to max_interval CONTAINS percent%"). In this case, the top activating token of an activation text falls outside the interval given in its group's title.
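A hypothetical validation sketch using the attribute names above (seq_data, feat_acts, and the title parsing are assumptions about the internals, not confirmed API):

import re

def parse_interval(title: str) -> tuple[float, float]:
    # Assumed title format: "INTERVAL <min> to <max> CONTAINS <percent>%"
    lo, hi = map(float, re.findall(r"-?\d+\.?\d*", title)[:2])
    return lo, hi

def check_groups(feature_data):
    # FeatureData -> sequence_data (SequenceMultiGroupData) -> seq_group_data list
    for group in feature_data.sequence_data.seq_group_data:
        lo, hi = parse_interval(group.title)
        for seq in group.seq_data:         # assumed field holding the group's sequences
            max_act = max(seq.feat_acts)   # assumed per-token activation values
            assert lo <= max_act <= hi, f"max act {max_act} outside [{lo}, {hi}]"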

Reproduction:

poetry shell
poetry run python neuronpedia.py generate

source set id: [enter any source set name]
SAE path: [enter path to the 0-res-jb SAE]
sparsity threshold: -5
features per batch: 20
batches to sample from: 4096
prompts to select from: 24576
resume from: 1

Example incorrect output:
The attached example.json is 0-res-jb, feature index 18, and I've removed all activations except the bugged one. Lines 308 and 309 show a "binMin" (group min_interval) of 0 and a "binMax" of 0.578, but lines 485 and 570 show that this text has a max activating token of 2.88. I also manually tested the text to confirm that the max activating value of 2.88 is correct.

example.json

Proposal: change colour handling code to scale with the max activation in a prompt.

Currently, the background colour logic sets all activations > 1 to the maximally orange colour, like so:

sae_vis/sae_vis/html_fns.py, lines 156 to 157 at commit 2740c00:

# ! Clip values in [0, 1] range (temporary)
bg_values = np.clip(feat_acts, 0, 1)

Wouldn't scaling the colours from 0 to the max activation in a prompt be generally better? Specifically, scale the feat_acts by

bg_values = np.maximum(feat_acts, 0) / max(1, np.max(feat_acts))

...so that we don't apply scaling in cases where the max activation in the prompt is less than 1.0, which would be distracting and would also error when nothing fires in the prompt.
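A toy comparison of the two schemes (the feat_acts values are made up):

import numpy as np

feat_acts = np.array([0.2, 0.9, 2.88])

clipped = np.clip(feat_acts, 0, 1)                             # current: clip to [0, 1]
scaled = np.maximum(feat_acts, 0) / max(1, np.max(feat_acts))  # proposed: scale by max

print(clipped)  # [0.2   0.9   1.  ]  -> 0.9 and 2.88 look almost the same
print(scaled)   # [0.069 0.312 1.  ]  -> relative magnitudes preserved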

Docstring of `compute_feat_acts` doesn't match function args

I noticed feature_idx isn't explained in the docstring, and the docstring mentions feature_act_dir and feature_bias, which aren't arguments of the function. Not a big issue (it seems obvious that feature_idx is the index of the SAE feature), but I thought I'd mention it.

def compute_feat_acts(
    model_acts: Float[Tensor, "batch seq d_in"],
    feature_idx: Int[Tensor, "feats"],
    encoder: AutoEncoder,
    encoder_B: Optional[AutoEncoder] = None,
    corrcoef_neurons: Optional[BatchedCorrCoef] = None,
    corrcoef_encoder_B: Optional[BatchedCorrCoef] = None,
) -> Float[Tensor, "batch seq feats"]:
    '''
    This function computes the feature activations, given a bunch of model data. It also updates the rolling correlation
    coefficient objects, if they're given.

    Args:
        model_acts: Float[Tensor, "batch seq d_in"]
            The activations of the model, which the SAE was trained on.
        feature_act_dir: Float[Tensor, "d_in feats"]
            The SAE's encoder weights for the feature(s) which we're interested in.
        feature_bias: Float[Tensor, "feats"]
            The bias of the encoder, which we add to the feature activations before ReLU'ing.
        encoder: AutoEncoder
            The encoder object, which we use to calculate the feature activations.
        encoder_B: Optional[AutoEncoder]
            The encoder-B object, which we use to calculate the feature activations.
        corrcoef_neurons: Optional[BatchedCorrCoef]
            The object which stores the rolling correlation coefficients between feature activations & neurons.
        corrcoef_encoder_B: Optional[BatchedCorrCoef]
            The object which stores the rolling correlation coefficients between feature activations & encoder-B features.
    '''

Use input tensor's device in some utils_fns, rather than utils_fns.device?

Hi Callum!
I'm looking at integrating your feature visualization tools into some of our SAE code at Apollo (primarily with @danbraunai-apollo and @Stefan-Heimersheim). Is this cool with you? If so, how do you feel about us potentially contributing to this codebase, or forking it? (We can set up a meeting to discuss if you have time.)

For now, I'm using some classes in utils_fns.py (TopK and QuantileCalculator), and I notice that they error when the device of the input tensor is not equal to utils_fns.device. It might be nicer for these functions to base their device on the input tensors, while preserving the current behaviour when the input tensors are already on that device. I would have submitted a PR, but I don't have access; my modified code is attached, and a sketch of the idea is below.
Code: utils_fns.zip
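A sketch of the suggested pattern (the quantile function here is illustrative, not the real QuantileCalculator API):

import torch

# Before: a module-level constant forces every input onto one device
device = "cuda" if torch.cuda.is_available() else "cpu"

def quantiles_fixed_device(x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    return torch.quantile(x, q.to(device))  # errors when x lives on a different device

# After: follow the input tensor's device, so CPU and CUDA inputs both just work
def quantiles_input_device(x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    return torch.quantile(x, q.to(x.device))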

Publish on PyPI

This is great, but it would be helpful if it could be published on PyPI so it's easier to pin versions and track breaking changes.

Set up tooling for type-checking

Currently, sae-vis has type hints added to the code, which is great, but there's no type-checking step in CI to validate that those types are correct. Pyright is probably the best choice for new projects, but MyPy is also a good option. This would entail setting up a GitHub Actions step to validate that types are correct on every commit / PR as well.
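As a toy illustration of the kind of mismatch such a step catches before runtime (hypothetical function, not from the codebase):

def scale_acts(acts: list[float], factor: float) -> list[float]:
    return [a * factor for a in acts]

scale_acts([0.1, 0.2], "2")  # flagged by pyright: "str" is not assignable to "float"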

Need support for byte-pair encoded utf-8 symbols

Currently, the tokens are decoded one by one and shown in the interface. However, for some UTF-8 symbols, if the corresponding tokens are decoded one by one, the result is unreadable byte fragments that cannot be displayed correctly on the panel.

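A sketch of the failure mode with GPT-2's byte-level BPE (requires transformers; the example string is arbitrary):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode("日本語")  # each character is 3 UTF-8 bytes, split across several tokens

per_token = [tok.decode([i]) for i in ids]  # fragments with '\ufffd' replacement characters
together = tok.decode(ids)                  # decoding the whole sequence restores the string
print(per_token, together)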
