Copyright 2024 The Treescope Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Getting Started: Pretty-Printing With Treescope

Treescope is an interactive, color-coded HTML pretty-printer, designed for use in IPython notebooks. It shows you the structure of any model or tree of arrays, and is especially suited to looking at nested data structures.

As its name suggests, treescope is specifically focused on inspecting treelike data, represented as nodes (Python objects) that contain collections of child nodes (other Python objects). This is pretty similar to the behavior of the ordinary Python repr, which produces a flat source-code-like view of an object and its contents.

(Treescope does support more general Python reference graphs and cyclic references as well, but it always renders them in a tree-like form.)

Setup

Let’s start by setting up the environment.

Imports

To run this notebook, you need a Python environment with treescope and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import treescope
except ImportError:
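  !pip install treescope

Then, in a separate cell (a from __future__ import must be the first statement in its cell), import the libraries used throughout this notebook:

from __future__ import annotations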

import typing
from typing import Any

import dataclasses

import jax
import jax.numpy as jnp
import numpy as np
import torch

import IPython
import treescope

Overview of Treescope

How does treescope work in practice? Here’s an example. Ordinarily, if you try to inspect a nested object containing NDArrays, you get something pretty hard to interpret. For instance, here’s a dictionary of arrays rendered using the default IPython pretty-printer:

some_arrays = {
    f"array_{i}": jax.random.normal(jax.random.key(i), (20, 50))
    for i in range(10)
}
some_arrays

And here’s how it looks in treescope:

with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
  treescope.display(some_arrays)

Treescope renders this object as a syntax-highlighted structure that can be interactively folded and unfolded.

(Try clicking any ▶ marker to expand a level of the tree, or any ▼ marker to collapse a level.)

Let’s register treescope as the default pretty-printer for IPython. This is the recommended way to use treescope in an interactive setting. Treescope is designed to be a drop-in replacement for the ordinary IPython pretty-printer, so you should be able to start using it right away.

treescope.basic_interactive_setup()

Foldable and unfoldable nested objects

Treescope lets you expand and collapse any level of your tree, so you can look at the parts you care about. In treescope, you can collapse or expand any object that would render as multiple lines (even if treescope doesn’t recognize the type!)

import dataclasses

@dataclasses.dataclass
class MyDataclass:
  a: Any
  b: Any
  c: Any

class TheZenOfPython:
  def __repr__(self):
    return (
        "<The Zen of Python:\nBeautiful is better than ugly.\n"
        "Explicit is better than implicit.\nSimple is better than complex.\n"
        "Complex is better than complicated.\nFlat is better than nested.\n"
        "Sparse is better than dense.\nReadability counts.\n"
        "Special cases aren't special enough to break the rules.\n"
        "Although practicality beats purity.\nErrors should never pass silently.\n"
        "Unless explicitly silenced.\n"
        "In the face of ambiguity, refuse the temptation to guess.\n"
        "There should be one-- and preferably only one --obvious way to do it.\n"
        "Although that way may not be obvious at first unless you're Dutch.\n"
        "Now is better than never.\nAlthough never is often better than *right* now.\n"
        "If the implementation is hard to explain, it's a bad idea.\n"
        "If the implementation is easy to explain, it may be a good idea.\n"
        "Namespaces are one honking great idea -- let's do more of those!>"
    )

[
    MyDataclass('a' * i, 'b' * i, ('cccc\n') * i)
    for i in range(10)
] + [
    MyDataclass(TheZenOfPython(), TheZenOfPython(), TheZenOfPython())
]
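The rule that anything rendering as multiple lines can be folded, even for unrecognized types, can be illustrated in plain Python. This is a minimal sketch, not treescope's implementation: is_foldable and collapsed_summary are hypothetical helpers that treat an object as foldable when its repr spans several lines, and summarize it by its first line when folded.

```python
def is_foldable(obj) -> bool:
  """Hypothetical rule: an object can fold if its repr spans several lines."""
  return "\n" in repr(obj)


def collapsed_summary(obj) -> str:
  """Hypothetical folded view: keep the first line and count the hidden ones."""
  text = repr(obj)
  if not is_foldable(obj):
    return text  # Single-line reprs have nothing to fold.
  first_line, rest = text.split("\n", 1)
  hidden = rest.count("\n") + 1
  return f"{first_line} ... ({hidden} more lines)"


class Poem:
  # A type the pretty-printer knows nothing about, with a multiline repr.
  def __repr__(self):
    return "<Poem:\nline one\nline two>"


print(collapsed_summary(Poem()))     # <Poem: ... (2 more lines)
print(collapsed_summary([1, 2, 3]))  # [1, 2, 3]
```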

Copyable key paths

Want to pull out an object deep inside a tree? You can click the copy icon next to any subtree to copy a function that accesses that subtree, as Python source code. You can then paste it into a code cell and use it to pull out the subtree you wanted.

Try it on one of the arrays in the visualization below! (If you run this notebook yourself, you should be able to copy paths with one click. If you are viewing this notebook on Colab without running it, you’ll need to click the icon and then copy the path manually due to Colab’s security restrictions.)

some_arrays

# for example
some_arrays['array_5']
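Conceptually, a copied path is just a function that walks from the root of the tree down to the chosen subtree. A minimal pure-Python sketch of that idea follows; make_accessor is a hypothetical helper, and the code treescope actually copies may look different (for example, paths through dataclass attributes would need attribute access rather than indexing).

```python
import operator
from functools import reduce


def make_accessor(*keys):
  """Hypothetical helper: returns a function that indexes root by each key in turn."""
  def accessor(root):
    # Equivalent to root[keys[0]][keys[1]]...; with no keys, returns root itself.
    return reduce(operator.getitem, keys, root)
  return accessor


tree = {"layers": [{"w": 1.0}, {"w": 2.0}]}
get_second_w = make_accessor("layers", 1, "w")
print(get_second_w(tree))  # 2.0
```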

Copyable code and roundtrip mode

Treescope follows the same conventions as Python’s repr, whose documentation says:

For many types, this function makes an attempt to return a string that would yield an object with the same value when passed to eval(); otherwise, the representation is a string enclosed in angle brackets that contains the name of the type of the object together with additional information often including the name and address of the object.

Accordingly, most of the output of treescope is valid Python syntax, and extra annotations are either hidden from selection or represented as Python comments.
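The convention quoted above can be checked directly with the standard library: for builtin values, eval(repr(x)) rebuilds an equal object, while objects without a custom __repr__ fall back to an angle-bracket form that is not valid Python. A small sketch:

```python
class Opaque:
  """A class with no custom __repr__, so it uses object's default."""


values = [1, "two", (3.0, None), {"k": [True, False]}]
for value in values:
  # For these builtins, repr() produces source code that evaluates to an equal value.
  assert eval(repr(value)) == value

# The default repr looks like "<__main__.Opaque object at 0x...>", which is not
# valid Python syntax, so it cannot be passed back to eval().
opaque_repr = repr(Opaque())
print(opaque_repr.startswith("<"))  # True
try:
  eval(opaque_repr)
except SyntaxError:
  print("not valid Python")
```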

Unfortunately, this isn’t always enough to rebuild the object, since it doesn’t tell you where custom types were defined. For instance, the rendering of JAX’s ShapeDtypeStruct doesn’t show where it was defined:

my_struct = jax.ShapeDtypeStruct(shape=(20, 10), dtype=jnp.float32)
my_struct

You can fix this by running treescope in “roundtrip mode”. By convention, this mode:

  • adds qualified names to all types,

  • adds angle brackets (< and >) around parts of the rendering that look like valid Python syntax but don’t actually rebuild the object.

To toggle roundtrip mode, click on any output of treescope and press the “r” key. (Try it above!) Alternatively, pass roundtrip_mode=True to the renderer:

treescope.display(my_struct, roundtrip_mode=True)
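Qualifying a type name amounts to recording which module the type comes from, so that the name can be imported again. A rough standard-library sketch using each type's __module__ and __qualname__ attributes; qualified_name and resolve are hypothetical helpers, not treescope APIs:

```python
import collections
import importlib


def qualified_name(obj) -> str:
  """Hypothetical helper: the 'module.QualifiedName' of obj's type."""
  cls = type(obj)
  return f"{cls.__module__}.{cls.__qualname__}"


def resolve(module_name: str, qualname: str):
  """Inverse lookup: import the module, then follow the dotted qualname."""
  obj = importlib.import_module(module_name)
  for part in qualname.split("."):
    obj = getattr(obj, part)
  return obj


counts = collections.OrderedDict(a=1)
print(qualified_name(counts))  # collections.OrderedDict
cls = type(counts)
print(resolve(cls.__module__, cls.__qualname__) is collections.OrderedDict)  # True
```

Types defined inside a function have <locals> in their __qualname__ and cannot be looked up this way, so a qualified name alone does not always make an object rebuildable.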

Function reflection and canonical aliases

Treescope has support for rendering useful information about functions and closures. The repr for functions isn’t always very helpful, especially if wrapped by JAX:

repr(jax.nn.relu)

Treescope tries to figure out where functions, function-like objects, and other constants are defined, and uses that to summarize them when collapsed. This works for ordinary functions defined anywhere, and also for function-like objects in the JAX public API (see well_known_aliases.py).

jax.nn.relu

For ordinary functions, it can even identify the file where the function was defined:

jnp.sum

This works even for locally-defined notebook functions:

def my_function():
  print("hello world!")
my_function
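Much of the information used for this kind of summary is available from Python's standard inspect module. The sketch below is a hypothetical describe_function helper, not treescope's actual logic: it unwraps decorator layers created with functools.wraps, then reports the function's module, qualified name, and source file when Python can find one.

```python
import functools
import inspect
import json


def describe_function(fn) -> str:
  """Hypothetical helper summarizing where a function is defined."""
  # Follow the __wrapped__ chain left behind by functools.wraps-based decorators.
  target = inspect.unwrap(fn)
  name = f"{target.__module__}.{target.__qualname__}"
  try:
    source_file = inspect.getsourcefile(target)
  except TypeError:
    # Builtins implemented in C have no Python source file.
    source_file = None
  return f"{name} (defined in {source_file})" if source_file else name


def logged(fn):
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    return fn(*args, **kwargs)
  return wrapper


@logged
def greet():
  print("hello world!")


print(describe_function(json.dumps))
print(describe_function(greet))
print(describe_function(len))  # builtins.len
```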

Array visualizer

Treescope includes a custom interactive NDArray visualizer designed to visualize the elements of high-dimensional arrays:

arr = (
    np.linspace(-10, 10, 20)
    * np.linspace(-10, 10, 15)[:, np.newaxis]
    * np.linspace(-1, 1, 5)[:, np.newaxis, np.newaxis]
)
treescope.render_array(arr)

It’s integrated with the rest of treescope, making it possible to directly visualize entire nested containers of arrays at once. Because we ran basic_interactive_setup above, arrays will be automatically visualized:

some_arrays['array_1']

You can also customize this in a given cell using a context manager:

with treescope.active_autovisualizer.set_scoped(
    treescope.ArrayAutovisualizer(maximum_size=100)
):
  treescope.display(some_arrays['array_1'])

Or the %%autovisualize magic:

%%autovisualize treescope.ArrayAutovisualizer(maximum_size=100)
treescope.display(some_arrays['array_1'])

%%autovisualize False
# ^ to turn it off
treescope.display(some_arrays['array_1'])

Customizable figure inlining

If you want more control over how arrays and other objects are visualized, you can write your own visualization function and configure treescope to use it:

# You can use most rich display objects, for instance a plotly figure:
import plotly.io
import plotly.express as px

# Treescope uses embed settings similar to Colab's, so configure plotly to
# render the same way it would in Colab:
plotly.io.renderers.default = "colab"

def visualize_with_histograms(value, path):
  if isinstance(value, (np.ndarray, jax.Array)):
    return treescope.IPythonVisualization(
        px.histogram(
            value.flatten(),
            width=400, height=200
        ).update_layout(
            margin=dict(l=20, r=20, t=20, b=20)
        )
    )
with treescope.active_autovisualizer.set_scoped(visualize_with_histograms):
  treescope.display(some_arrays)
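An autovisualizer like the one above is a callable taking (value, path) and returning either a visualization or None for values it doesn't handle, as visualize_with_histograms does implicitly for non-arrays. That makes visualizers easy to compose. The sketch below is a hypothetical first_match combinator, not a treescope API: it tries each visualizer in order and returns the first non-None result. Plain strings stand in for real visualization objects so the sketch runs without treescope.

```python
from typing import Any, Callable, Optional

# A visualizer maps (value, path) to a visualization, or None to decline.
Visualizer = Callable[[Any, Any], Optional[Any]]


def first_match(*visualizers: Visualizer) -> Visualizer:
  """Hypothetical combinator: use the first visualizer that returns non-None."""
  def combined(value, path):
    for visualizer in visualizers:
      result = visualizer(value, path)
      if result is not None:
        return result
    return None  # Nothing matched: leave rendering to the default.
  return combined


# Stand-in visualizers that return strings instead of treescope visualizations.
def visualize_floats(value, path):
  if isinstance(value, float):
    return f"float at {path}"


def visualize_strings(value, path):
  if isinstance(value, str):
    return f"str at {path}"


visualizer = first_match(visualize_floats, visualize_strings)
print(visualizer(1.5, "x"))   # float at x
print(visualizer("hi", "y"))  # str at y
print(visualizer(3, "z"))     # None
```

Assuming treescope accepts any such callable, a combined function built this way should be usable with treescope.active_autovisualizer.set_scoped just like visualize_with_histograms.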

See the separate “Building Custom Visualizations” tutorial for more info on how to customize automatic visualization!

Where you can use treescope

In IPython / Colab

Treescope works great in IPython and Colab notebooks, and is designed as a drop-in replacement for the IPython pretty-printer.

We’ve already done it above, but you can configure treescope as the default IPython formatter by calling

treescope.register_as_default()
# ^ called by default when running treescope.basic_interactive_setup()

or manually display specific objects with

treescope.display(["some object"])

There’s also a helper function to show multiple objects with syntax similar to Python’s print:

treescope.show("A value:", ["some object"])

If you register treescope as the default IPython formatter, you can also just do

["some object"]

In ordinary Python for offline viewing

Treescope can render directly to static HTML, without requiring any dynamic communication between the Python kernel and the HTML renderer. This means you can directly save the output of a treescope rendering to an HTML file, and open it later to view whatever was formatted.

with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
  contents = treescope.render_to_html(some_arrays)

with open("/tmp/treescope_output.html", "w") as f:
  f.write(contents)

# Uncomment to download the file:
# import google.colab.files
# google.colab.files.download("/tmp/treescope_output.html")
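The hard-coded /tmp path above assumes a Unix-like system. A slightly more portable sketch writes to the platform's temporary directory and pins the encoding to UTF-8, since rendered HTML may contain non-ASCII characters and the default file encoding varies by platform. save_html is a hypothetical helper; in practice, contents would be the string returned by treescope.render_to_html.

```python
import pathlib
import tempfile


def save_html(contents: str, filename: str = "treescope_output.html") -> pathlib.Path:
  """Hypothetical helper: write HTML to the system temp directory as UTF-8."""
  path = pathlib.Path(tempfile.gettempdir()) / filename
  path.write_text(contents, encoding="utf-8")
  return path


# Stand-in for the output of treescope.render_to_html(...):
saved = save_html("<div>example ▶</div>")
print(saved)
```

To view the result locally, the standard library's webbrowser.open(saved.as_uri()) opens the saved file in a browser.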

Things treescope can render

Treescope has support for a large number of common Python objects.

Dicts, lists, tuples, and sets

[
    [(), (1,), (1, 2, 3)],
    {"foo": "bar", "baz": "qux"},
    {(1,2,3):(4,5,6), (7,8,9):(10,11,12)},
    {"a", "b", "c", "d"}
]

Builtins and literals

(with special handling for multiline strings)

[
    [1, 2, 3, 4],
    ["a", "b", "c", "d"],
    [True, False, None, NotImplemented, Ellipsis],
    ["a\n  multiline\n    string"]
]

Dataclasses and namedtuples

class Foo(typing.NamedTuple):
  a: int
  b: str

Foo(a=1, b="bar")

@dataclasses.dataclass(frozen=True)
class Bar:
  c: str
  d: int
  some_list: list = dataclasses.field(default_factory=list)

IPython.display.display(Bar(c="bar", d=2))

In roundtrip mode, treescope will even help you rebuild dataclasses with weird __init__ methods:

@dataclasses.dataclass
class WeirdInitClass:
  foo: int

  def __init__(self, half_foo: int):
    self.foo = 2 * half_foo

# This shows as WeirdInitClass(foo=4):
treescope.display(WeirdInitClass(2))

# But in roundtrip mode (explicit or after pressing `r`), it shows as a call
# to a dataclass_from_attributes helper (with qualified names) that sets foo=4.
# This bypasses __init__ and rebuilds the dataclass's attributes directly,
# since __init__ doesn't take `foo` as an argument.
treescope.display(WeirdInitClass(2), roundtrip_mode=True)
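The idea of rebuilding a dataclass from its attributes can be sketched in plain Python: allocate the instance without running __init__, then set each field directly. rebuild_from_attributes below is a hypothetical illustration rather than treescope's own helper; it uses object.__setattr__ so that it also works for frozen dataclasses.

```python
import dataclasses


def rebuild_from_attributes(cls, **field_values):
  """Hypothetical helper: create cls without running __init__, then set fields."""
  field_names = {field.name for field in dataclasses.fields(cls)}
  if set(field_values) != field_names:
    raise ValueError(f"expected fields {sorted(field_names)}, got {sorted(field_values)}")
  instance = object.__new__(cls)  # Allocates the object; __init__ is never called.
  for name, value in field_values.items():
    # object.__setattr__ sidesteps the FrozenInstanceError of frozen dataclasses.
    object.__setattr__(instance, name, value)
  return instance


@dataclasses.dataclass
class WeirdInitClass:
  foo: int

  def __init__(self, half_foo: int):
    self.foo = 2 * half_foo


rebuilt = rebuild_from_attributes(WeirdInitClass, foo=4)
print(rebuilt)                       # WeirdInitClass(foo=4)
print(rebuilt == WeirdInitClass(2))  # True
```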

Multidimensional arrays / tensors

Treescope supports showing a variety of arrays and tensors, including Numpy, JAX, and PyTorch arrays. It shows them by summarizing the shape, mean, standard deviation, bounds, and number of special values, and will also visualize their contents automatically (if automatic visualization is enabled).

[
    jnp.arange(1000),
    np.array([[np.nan] * 100, [0] * 50 + [1] * 50]),
    torch.linspace(-10, 20, 50),
]
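The summary statistics described above are easy to reproduce by hand, which can help when checking a collapsed array. A minimal NumPy sketch; summarize_array is a hypothetical helper, and it assumes (possibly unlike treescope) that NaN and infinite values are excluded from the mean, standard deviation, and bounds, and counted separately.

```python
import numpy as np


def summarize_array(arr) -> dict:
  """Hypothetical summary: shape, dtype, stats over finite values, special-value counts."""
  arr = np.asarray(arr)
  finite = arr[np.isfinite(arr)]  # Flattened array of the finite elements only.
  summary = {
      "shape": arr.shape,
      "dtype": str(arr.dtype),
      "nan": int(np.isnan(arr).sum()),
      "inf": int(np.isinf(arr).sum()),
  }
  if finite.size:
    summary.update(
        mean=float(finite.mean()),
        std=float(finite.std()),
        min=float(finite.min()),
        max=float(finite.max()),
    )
  return summary


print(summarize_array(np.array([[np.nan] * 100, [0] * 50 + [1] * 50])))
```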

Treescope also supports arrays with named axes. For example, when visualizing a Penzai NamedArray, axis names are automatically shown:

from penzai import pz
pz.nx.wrap(
    jax.random.normal(jax.random.key(1), (10, 4, 16)),
).tag("query_seq", "heads", "embed")

When used in IPython, Treescope will try to render the tree structure first and then insert array visualizations later. This can make visualization faster and can sometimes let you see the shape of JAX arrays before JAX has finished computing their values.

Neural network models

Treescope can render a variety of neural network models.

In the Penzai and Equinox neural network libraries, model objects are represented as Python dataclasses. In this case, Treescope will render the dataclass attributes, similar to the ordinary repr for a dataclass. For instance, here’s a simple Penzai model:

from penzai.models import simple_mlp
simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[8, 32, 32, 8]
)

In PyTorch, models are represented as dynamic Python objects which can have arbitrary attributes. Treescope will inspect the model object and visualize its non-private configuration attributes, submodules, and parameters:

torch.nn.Sequential(
    torch.nn.Linear(8, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 8),
)

Custom classes can also use Treescope’s extension API to modify their renderings. For instance, Flax’s NNX modules support the extension API and can also be used with Treescope.

Functions

(As discussed in the “Function reflection and canonical aliases” section above.)

[
    jnp.sum,
    dataclasses.dataclass,
    lambda x: x + 2,
    jax.vmap(lambda x: x),
]

Arbitrary PyTree types

Treescope uses a fallback rendering strategy to show the children of any PyTree type registered with JAX, even if treescope has no dedicated handler for that type.

jax.tree_util.Partial(lambda x, y, z: x + y, 10, y=100)

Partial support: Repeated Python object references

Treescope will warn you if it sees multiple references to the same mutable object, since that can cause unexpected behavior. (In this case, copying the output won’t copy the shared reference structure.)

my_shared_list = []

{
    "foo": my_shared_list,
    "bar": my_shared_list,
    "baz": [1, 2, my_shared_list]
}
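Spotting repeated references comes down to tracking object identities while walking the tree. A minimal sketch with a hypothetical find_shared_references helper; unlike treescope, it only descends into lists, tuples, dicts, and sets, and it reports mutable containers (lists, dicts, sets) reached more than once.

```python
from collections import Counter


def find_shared_references(root) -> list:
  """Hypothetical helper: (container, count) pairs for mutable containers seen 2+ times."""
  counts = Counter()
  objects = {}  # Keeps each object alive so its id() stays unique during the walk.
  stack = [root]
  while stack:
    node = stack.pop()
    if isinstance(node, (list, dict, set)):
      counts[id(node)] += 1
      objects[id(node)] = node
      if counts[id(node)] > 1:
        continue  # Already expanded once; skipping it also avoids infinite cycles.
    if isinstance(node, dict):
      stack.extend(node.values())
    elif isinstance(node, (list, tuple, set)):
      stack.extend(node)
  return [(objects[key], n) for key, n in counts.items() if n > 1]


my_shared_list = []
tree = {"foo": my_shared_list, "bar": my_shared_list, "baz": [1, 2, my_shared_list]}
shared = find_shared_references(tree)
print(len(shared), shared[0][1])  # 1 3
```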