
Lucaweihs commented on September 22, 2024

Hi @Hustwireless,

Thanks for catching this! We've recently updated a few model architectures which made the "rnn" tag out of date for this model. I've just pushed a commit that should fix this error, let me know if this works for you.

from allenact.

Hustwireless commented on September 22, 2024

Hi @Lucaweihs,

Thanks for the fix! It works now!

Since I've just started using this framework, many things are not quite clear to me. I'm curious why adding this rollout_source makes it work. Could you briefly explain what the tags in rollout_source do?

Much appreciated!


Lucaweihs commented on September 22, 2024

Hi @Hustwireless,

Sure thing. For reference, here's the relevant code from the experiment configuration file:

self.viz = VizSuite(
    episode_ids=self.viz_ep_ids,
    mode=mode,
    # Basic 2D trajectory visualizer (task output source):
    base_trajectory=TrajectoryViz(
        path_to_target_location=("task_info", "target",),
    ),
    # Egocentric view visualizer (vector task source):
    egeocentric=AgentViewViz(
        max_video_length=100, episode_ids=self.viz_video_ids
    ),
    # Default action probability visualizer (actor critic output source):
    action_probs=ActorViz(figsize=(3.25, 10), fontsize=18),
    # Default taken action logprob visualizer (rollout storage source):
    taken_action_logprobs=TensorViz1D(),
    # Same episode mask visualizer (rollout storage source):
    episode_mask=TensorViz1D(rollout_source=("masks",)),
    # Default recurrent memory visualizer (rollout storage source):
    rnn_memory=TensorViz2D(rollout_source=("memory", "single_belief")),
    # Specialized 2D trajectory visualizer (task output source):
    thor_trajectory=ThorViz(
        figsize=(16, 8),
        viz_rows_cols=(448, 448),
        scenes=("FloorPlan_Train{}_{}", 1, 1, 1, 1),
    ),
)

This code instantiates a class that handles visualizing various metrics during training (in particular, saving these visualizations to a TensorBoard log). For instance, thor_trajectory=ThorViz(...) generates top-down visualizations of the agent's trajectory (see the visualizations with "trajectory" in their label at the bottom of the tutorial).
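To make the idea of a suite of named visualizers concrete, here is a toy sketch in plain Python. `ToyVizSuite` and `on_episode_end` are hypothetical names, not AllenAct's actual VizSuite API: each keyword argument is assumed to be a callable that turns a recorded episode into some output, only the selected episode IDs are visualized, and a plain dict stands in for the TensorBoard log.

```python
from typing import Any, Callable, Dict, List

# Hypothetical visualizer type: takes a recorded episode, returns something to log.
Visualizer = Callable[[Dict[str, Any]], Any]


class ToyVizSuite:
    """Toy stand-in for a visualization suite (NOT AllenAct's implementation)."""

    def __init__(self, episode_ids: List[str], **visualizers: Visualizer) -> None:
        self.episode_ids = set(episode_ids)
        self.visualizers = visualizers
        # One log entry list per named visualizer; stands in for TensorBoard.
        self.log: Dict[str, List[Any]] = {name: [] for name in visualizers}

    def on_episode_end(self, episode_id: str, episode: Dict[str, Any]) -> None:
        if episode_id not in self.episode_ids:
            return  # only the selected episodes are visualized
        for name, viz in self.visualizers.items():
            self.log[name].append(viz(episode))


suite = ToyVizSuite(
    episode_ids=["ep0"],
    trajectory=lambda ep: ep["positions"],
    episode_length=lambda ep: len(ep["positions"]),
)
suite.on_episode_end("ep0", {"positions": [(0, 0), (0, 1), (1, 1)]})
suite.on_episode_end("ep1", {"positions": [(0, 0)]})  # not selected, ignored
print(suite.log)
# {'trajectory': [[(0, 0), (0, 1), (1, 1)]], 'episode_length': [3]}
```

Passing visualizers as keyword arguments mirrors the shape of the configuration above, where each keyword (base_trajectory, rnn_memory, ...) names one visualizer.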

The piece of code that was causing the problem was rnn_memory=TensorViz2D(), which is meant to (1) take the agent's hidden belief state (i.e., its representation of the environment; in this case the 512-dimensional output of the agent's GRU) at every step of an episode, (2) concatenate these hidden states into a T x 512 matrix (where T is the number of steps the agent took in the episode), and (3) create a heatmap from this matrix. This lets you get a sense of how the agent's hidden state changes during training (e.g., see the four heatmaps at the bottom of the tutorial labeled test/memory/rnn_group0).
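Steps (2) and (3) can be sketched in plain Python as follows. This is an illustration under assumptions, not TensorViz2D's actual code: the per-step beliefs are assumed to arrive as plain lists of floats, and min-max normalization to [0, 1] is an assumed choice for turning the T x D matrix into heatmap intensities. `stack_beliefs` and `to_heatmap` are hypothetical helper names.

```python
import random
from typing import List, Sequence


def stack_beliefs(per_step_beliefs: Sequence[Sequence[float]]) -> List[List[float]]:
    """Step (2): stack T per-step hidden vectors into a T x D matrix (list of rows)."""
    if not per_step_beliefs:
        raise ValueError("episode has no steps")
    dim = len(per_step_beliefs[0])
    for t, h in enumerate(per_step_beliefs):
        if len(h) != dim:
            raise ValueError(f"step {t} has size {len(h)}, expected {dim}")
    return [list(h) for h in per_step_beliefs]


def to_heatmap(matrix: List[List[float]]) -> List[List[float]]:
    """Step (3), assumed scheme: min-max normalize all entries into [0, 1]."""
    flat = [v for row in matrix for v in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        return [[0.0] * len(row) for row in matrix]  # constant matrix: flat heatmap
    return [[(v - lo) / span for v in row] for row in matrix]


# Fake episode: T = 5 steps of a 512-dimensional belief (random stand-in for GRU output).
random.seed(0)
T, D = 5, 512
episode = [[random.uniform(-1.0, 1.0) for _ in range(D)] for _ in range(T)]
heat = to_heatmap(stack_beliefs(episode))
print(len(heat), len(heat[0]))  # 5 512
```

The resulting matrix could then be rendered with any image or heatmap routine (for example matplotlib's `imshow`) and written to TensorBoard.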

To get the belief state from the agent to the visualizer during training, we need to tell the visualizer where to look for it. Adding rollout_source=("memory", "single_belief") tells the visualizer to look in the agent's rollout (i.e., the history of its states and actions) and pick out the "single_belief" key from the agent's "memory". The code broke because the architecture we use for this task (see here) has changed, and "single_belief" used to be called "rnn".
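The lookup described above can be pictured as walking a key path through nested storage. The sketch below assumes a hypothetical rollout stored as nested Python dicts (AllenAct's real rollout storage is structured differently), and `resolve_rollout_source` is a hypothetical helper. It shows why a path written for the old key "rnn" fails once the memory entry is named "single_belief".

```python
from typing import Any, Mapping, Sequence


def resolve_rollout_source(rollout: Mapping[str, Any], path: Sequence[str]) -> Any:
    """Follow a key path such as ("memory", "single_belief") into nested storage."""
    node: Any = rollout
    for depth, key in enumerate(path):
        if not isinstance(node, Mapping) or key not in node:
            available = sorted(node) if isinstance(node, Mapping) else []
            raise KeyError(
                f"rollout_source {tuple(path)!r}: no key {key!r} at depth {depth}; "
                f"available: {available}"
            )
        node = node[key]
    return node


# Hypothetical rollout after the architecture change: the memory key is now
# "single_belief" (previously "rnn").
rollout = {
    "masks": [1.0, 1.0, 0.0],
    "memory": {"single_belief": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]},
}

print(resolve_rollout_source(rollout, ("memory", "single_belief")))
# [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

try:
    resolve_rollout_source(rollout, ("memory", "rnn"))  # the outdated tag
except KeyError as err:
    print("lookup failed:", err)
```

Reporting the available keys in the error is a design choice that makes this kind of rename easy to spot.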

This kind of visualization code is definitely an "advanced" topic in AllenAct; even I generally just use the default TensorBoard graphs, which are generated without specifying any custom visualizers.

Let me know if that helps or if you have any other questions.


Hustwireless commented on September 22, 2024

Hi @Lucaweihs, thanks for this detailed walk through, it's super clear and helpful!

