Copyright 2024 The Treescope Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



Visualizing Arrays with Treescope

High-dimensional NDArray (or tensor) data is common in many machine learning settings, but most plotting libraries are designed for either 2D image data or 1D time series data. Treescope includes a powerful visualizer for arrays of arbitrarily high dimension, designed to make it easy to quickly summarize NDArrays without writing manual plotting logic.

This notebook is primarily written in terms of NumPy arrays, but the visualizer also works for other array types, including JAX arrays, PyTorch tensors, and Penzai NamedArrays!

Setup

To run this notebook, you need a Python environment with treescope and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

try:
  import treescope
except ImportError:
  !pip install treescope
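from __future__ import annotations

import os
# Simulate 8 CPU devices for the sharding examples below.
# This must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from typing import Any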

import jax
import jax.numpy as jnp
import numpy as np

import IPython
import treescope
treescope.register_autovisualize_magic()

Visualizing NDArrays with treescope.render_array

Visualizing numeric data and customizing colormaps

Arrays can be directly rendered using default settings by passing them to treescope.render_array:

help(treescope.render_array)
my_array = np.cos(np.arange(300).reshape((10,30)) * 0.2)

treescope.render_array(my_array)

Things to notice:

  • The visualization is interactive! (Try zooming in and out, hovering over the array to inspect individual elements, or clicking to remember a particular element.)

  • The shape of the array can be read off by looking at the axis labels.

  • Pixels are always square in Treescope’s array renderings. (In fact, they are always exactly 7 pixels by 7 pixels at zoom level 1.)

The default rendering strategy uses a diverging colormap centered at zero, with blue for positive numbers and red for negative ones, so that you can see both the magnitude and the sign of each value. You can switch to a relative mode by passing the argument around_zero=False:

treescope.render_array(my_array, around_zero=False)

You can also customize the upper and lower bounds of the colormap by passing vmin and/or vmax:

treescope.render_array(my_array, vmax=0.7)

In this case, the array has values outside of our specified colormap bounds; those out-of-bounds values are rendered with “+” and “-” to indicate that they’ve been clipped.

Since we didn’t pass around_zero=False, Treescope automatically set vmin to -vmax for us. You can also set both bounds explicitly:

treescope.render_array(my_array, vmin=-0.1, vmax=0.7)
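To check ahead of time how many cells will carry the “+” and “-” clipping markers, you can compare the array against the same bounds with NumPy. This is a plain NumPy sketch of the clipping rule described above, not Treescope’s internal code. count_clipped is a hypothetical helper name.

```python
import numpy as np

def count_clipped(array, vmin, vmax):
  """Returns (number of values above vmax, number of values below vmin)."""
  array = np.asarray(array)
  above = int(np.count_nonzero(array > vmax))  # these would be marked "+"
  below = int(np.count_nonzero(array < vmin))  # these would be marked "-"
  return above, below

my_array = np.cos(np.arange(300).reshape((10, 30)) * 0.2)
# Same bounds as the explicit vmin/vmax call above.
above, below = count_clipped(my_array, vmin=-0.1, vmax=0.7)
print(f"{above} values above vmax, {below} values below vmin")
```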

If you want to customize the way colors are rendered, you can pass a custom colormap as a list of (R, G, B) color tuples:

# palettable is a separate third-party package (pip install palettable).
import palettable
treescope.render_array(my_array, colormap=palettable.matplotlib.Inferno_20.colors)
treescope.render_array(my_array, colormap=palettable.cmocean.sequential.Speed_20.colors)
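If palettable isn’t installed, you can also build a colormap list by hand. The sketch below linearly interpolates between a few anchor colors. It assumes that each entry is an (R, G, B) tuple of integers from 0 to 255, which is the format of palettable’s .colors lists used above. make_colormap is a hypothetical helper, and the anchor colors are arbitrary choices.

```python
import numpy as np

def make_colormap(anchors, num_colors=20):
  """Linearly interpolates between (R, G, B) anchor colors given as 0-255 ints."""
  anchors = np.asarray(anchors, dtype=np.float64)
  # Positions of the anchors and of the output colors along [0, 1].
  anchor_pos = np.linspace(0.0, 1.0, len(anchors))
  out_pos = np.linspace(0.0, 1.0, num_colors)
  # Interpolate each of the R, G, B channels independently.
  channels = [np.interp(out_pos, anchor_pos, anchors[:, c]) for c in range(3)]
  return [tuple(int(round(float(v))) for v in rgb) for rgb in zip(*channels)]

# Black -> purple -> yellow.
my_colormap = make_colormap([(0, 0, 0), (120, 30, 120), (250, 230, 60)])
```

You could then pass it as treescope.render_array(my_array, colormap=my_colormap).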

Visualizing high-dimensional arrays

So far we’ve been looking at an array with two axes, but Treescope’s array renderer works out-of-the-box with arbitrarily high-dimensional arrays as well:

my_4d_array = np.cos(np.arange(5*6*7*8).reshape((5,6,7,8)) * 0.1)
treescope.render_array(my_4d_array)

For high-dimensional arrays, the individual axis labels indicate which level of the plot corresponds to which axis. Above, each 7x8 facet represents a slice my_4d_array[i,j,:,:], with individual pixels ranging along axis 2 and axis 3; this is denoted by the axis2 and axis3 labels for that facet. The six columns of facets correspond to slices along axis 1, and the five rows correspond to slices along axis 0, as denoted by the outermost labels for those axes.
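To relate the facets to ordinary NumPy indexing, here is a rough sketch that tiles the 4D array into a single 2D image using the layout described above, but without the gaps between facets. It assumes that axis 2 runs vertically and axis 3 horizontally within each facet, matching the 7x8 facet shape. This illustrates the layout only and is not how Treescope builds its rendering.

```python
import numpy as np

my_4d_array = np.cos(np.arange(5*6*7*8).reshape((5, 6, 7, 8)) * 0.1)
n0, n1, n2, n3 = my_4d_array.shape

# Put the vertical axes (0, then 2) first and the horizontal axes
# (1, then 3) last, then merge each pair into a single image axis.
tiled = my_4d_array.transpose(0, 2, 1, 3).reshape(n0 * n2, n1 * n3)

# The facet in row i, column j of the grid is the slice my_4d_array[i, j, :, :].
i, j = 3, 4
facet = tiled[i * n2:(i + 1) * n2, j * n3:(j + 1) * n3]
assert np.array_equal(facet, my_4d_array[i, j])
```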

If you want, you can control which axes are assigned to which direction, listing them from innermost to outermost:

treescope.render_array(my_4d_array, columns=[2, 0, 1])

Note that the gap between the “axis0” groups is twice as large as the gap between “axis2” groups, so that you can visually distinguish the groups.

Treescope can also visualize Penzai’s NamedArray, using its axis names as labels. This means that, if your code is written with NamedArrays, you get labeled visualizations for free!

from penzai import pz
my_named_array = pz.nx.wrap(
    jax.random.normal(jax.random.key(1), (10, 4, 16)),
).tag("query_seq", "heads", "embed")

treescope.render_array(my_named_array, columns=["embed", "heads"])

(See the NamedArray tutorial for more information on how NamedArrays work in Penzai.)

Identifying extreme or invalid array values

By default, Treescope’s array renderer tries to configure the colormap to show interesting detail, clipping outliers. Specifically, it limits the colormap to 3 standard deviations away from the mean (or, technically, from zero if around_zero is set):

my_outlier_array = np.cos(np.arange(300).reshape((10,30)) * 0.2)
my_outlier_array[4, 2] = 10.0
treescope.render_array(my_outlier_array)
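The clipping rule described above can be approximated with explicit vmin/vmax values. The sketch below follows the stated rule: it centers at zero (or at the mean when around_zero=False) and extends three standard deviations in each direction. Treescope’s exact computation may differ in details, such as which statistic the deviation is measured around or how non-finite values are handled. approx_default_bounds is a hypothetical helper.

```python
import numpy as np

def approx_default_bounds(array, around_zero=True, num_stddevs=3.0):
  """Approximates the outlier-clipping colormap bounds described above."""
  finite = np.asarray(array)[np.isfinite(array)]  # ignore inf/NaN
  center = 0.0 if around_zero else float(finite.mean())
  spread = num_stddevs * float(finite.std())
  return center - spread, center + spread

my_outlier_array = np.cos(np.arange(300).reshape((10, 30)) * 0.2)
my_outlier_array[4, 2] = 10.0
vmin, vmax = approx_default_bounds(my_outlier_array)
# The single outlier (10.0) lies well outside these bounds, so it is clipped.
print(vmin, vmax)
```

Passing these values, as in treescope.render_array(my_outlier_array, vmin=vmin, vmax=vmax), should give a rendering similar to the default one.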

It also flags invalid values (infinities and NaNs) by drawing annotations on top of the visualization:

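numerator = np.linspace(-5, 5, 31)
denominator = np.linspace(-1, 1, 13)
# Dividing by zero is intentional here; silence NumPy's RuntimeWarnings.
with np.errstate(divide="ignore", invalid="ignore"):
  array_with_infs_and_nans = numerator[None, :] / denominator[:, None]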
treescope.render_array(array_with_infs_and_nans)

Above, “I” (white on a blue background) denotes positive infinity, “-I” (white on a red background) denotes negative infinity, and “X” (in magenta on a black background) denotes NaN. (You can also see the outlier-clipping behavior clipping a few of the largest finite values here.)
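To cross-check these annotations, you can count each kind of invalid value directly with NumPy’s np.isposinf, np.isneginf and np.isnan. The np.errstate block only silences the divide-by-zero warnings that this example produces on purpose.

```python
import numpy as np

numerator = np.linspace(-5, 5, 31)
denominator = np.linspace(-1, 1, 13)
with np.errstate(divide="ignore", invalid="ignore"):
  array_with_infs_and_nans = numerator[None, :] / denominator[:, None]

num_pos_inf = int(np.isposinf(array_with_infs_and_nans).sum())  # shown as "I"
num_neg_inf = int(np.isneginf(array_with_infs_and_nans).sum())  # shown as "-I"
num_nan = int(np.isnan(array_with_infs_and_nans).sum())         # shown as "X"
print(num_pos_inf, num_neg_inf, num_nan)
```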

This works in relative mode too:

treescope.render_array(array_with_infs_and_nans, around_zero=False)

If you want, you can mask out data by providing a “valid mask”. Only values where the mask is True will be rendered; masked-out data is shown in gray with black dots.

valid_mask = np.isfinite(array_with_infs_and_nans) & (np.abs(array_with_infs_and_nans) < 10)
treescope.render_array(
    array_with_infs_and_nans,
    valid_mask=valid_mask,
)

Visualizing categorical data

Treescope’s array renderer also supports rendering categorical data, even with very high numbers of categories. Data with a discrete (integer or boolean) dtype is rendered as categorical by default, with different colors for different categories:

treescope.render_array(np.arange(10))
treescope.render_array(np.array([True, False, False, True, True]))

The values from 0 to 9 are rendered with solid colors, with 0 represented as white. Larger numbers are rendered using nested box patterns, with one box per digit of the number, and the color of the box indicating the value of the digit:

treescope.render_array(np.arange(1000).reshape((10,100)))
treescope.render_array(
    jnp.arange(20)[:, None] * jnp.arange(20)[None, :]
)

You can also render a single integer on its own with treescope.integer_digitbox, if you want (sometimes useful for custom visualizations). Digit boxes support integers with up to 7 digits.

treescope.integer_digitbox(42, size="30px")
treescope.integer_digitbox(1234, size="30px")
treescope.integer_digitbox(7654321, size="30px")

Negative integers render the same way as positive ones, but with a black triangle in the corner indicating the sign:

treescope.render_array(np.arange(21 * 21).reshape((21, 21)) - 220)
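The nested-box encoding is based on the base-10 digits of each value. As a rough illustration (not Treescope’s rendering code), this sketch splits an integer into the information the boxes encode. That information is the sign, shown by the corner triangle, and the digits, with one box per digit colored by its value. integer_to_digits is a hypothetical helper.

```python
def integer_to_digits(value: int) -> tuple[bool, list[int]]:
  """Splits an integer into (is_negative, base-10 digits, most significant first)."""
  is_negative = value < 0
  digits = [int(d) for d in str(abs(value))]
  return is_negative, digits

print(integer_to_digits(1234))  # (False, [1, 2, 3, 4])
print(integer_to_digits(-220))  # (True, [2, 2, 0])
```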

If your data has a discrete dtype but you don’t want to render it as categorical, you can pass continuous=True to render it as numeric data instead:

treescope.render_array(np.arange(21 * 21).reshape((21, 21)) - 220, continuous=True)

Adding axis labels

Individual axes of arrays often have some semantic meaning, e.g. the “batch” or “features” axes. If you’re using named axes, Treescope will pick those up automatically! But you can also provide them explicitly:

treescope.render_array(
    jnp.arange(20)[:, None] * jnp.arange(20)[None, :],
    axis_labels={0: "foo", 1: "bar"},
)

For some arrays, it can also be useful to associate labels with the individual indices along each axis. For instance, we might want to label a “classes” axis with each individual class, or a “sequence” axis with the tokens of the sequence.

Treescope’s array renderer allows you to pass this kind of metadata as an extra argument, and will show it to you when you hover over or click on elements of the array with your mouse.

For positional axes, you can pass any subset of the axes by position:

# Try hovering or clicking:
treescope.render_array(
    np.sin(np.linspace(0, 100, 12 * 5 * 7)).reshape((12, 5, 7)),
    axis_item_labels={
        1: ["foo", "bar", "baz", "qux", "xyz"],
        0: ["one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "eleven", "twelve"],
    }
)
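When labels are generated programmatically, a label list can easily end up with a different length from its axis. The sketch below checks positional labels against an array’s shape before you pass them to render_array. check_axis_item_labels is a hypothetical helper, not part of Treescope.

```python
import numpy as np

def check_axis_item_labels(shape, axis_item_labels):
  """Raises ValueError if any label list doesn't match its axis length."""
  for axis, labels in axis_item_labels.items():
    if not 0 <= axis < len(shape):
      raise ValueError(f"axis {axis} out of range for shape {shape}")
    if len(labels) != shape[axis]:
      raise ValueError(
          f"axis {axis} has length {shape[axis]} but got {len(labels)} labels"
      )

array = np.sin(np.linspace(0, 100, 12 * 5 * 7)).reshape((12, 5, 7))
labels = {
    1: ["foo", "bar", "baz", "qux", "xyz"],
    0: [f"item{i}" for i in range(12)],  # generated labels
}
check_axis_item_labels(array.shape, labels)  # passes silently
```

The checked dictionary can then be passed as treescope.render_array(array, axis_item_labels=labels).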

For named axes, you can pass labels by name. Irrelevant labels are simply ignored.

Slicing and “scrubbing” with sliders

It’s sometimes useful to look at individual slices of a large array one at a time, instead of viewing them all at once. In addition to the columns and rows arguments, render_array supports a sliders argument, which displays a slider for each of those axes and lets you “scrub” through its indices:

time = jnp.arange(100)[:, None, None]
col = np.linspace(-2, 2, 15)[None, :, None]
row = np.linspace(-2, 2, 15)[None, None, :]

values_over_time = jax.nn.sigmoid(
    0.05 * time - 2 - row - jnp.sin(2 * col - 0.1 * time)
)

# Try sliding the slider:
treescope.render_array(
    values_over_time,
    columns=[1],
    sliders=[0],
    # col varies along axis 1 and row along axis 2 (see their shapes above).
    axis_labels={
        0: "time",
        1: "col",
        2: "row",
    }
)
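Conceptually, a slider on axis 0 fixes one index along that axis and displays the remaining 2D slice. This NumPy-only sketch rebuilds the same array, with a hand-written sigmoid in place of jax.nn.sigmoid. It then extracts the frame that the slider would show at one time step.

```python
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

time = np.arange(100)[:, None, None]
col = np.linspace(-2, 2, 15)[None, :, None]
row = np.linspace(-2, 2, 15)[None, None, :]
values_over_time = sigmoid(0.05 * time - 2 - row - np.sin(2 * col - 0.1 * time))

# With sliders=[0], moving the slider to index t displays this 15x15 slice.
t = 50
frame = values_over_time[t]
print(frame.shape)  # (15, 15)
```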

You can even put sliders for multiple axes simultaneously, if you want:

row_wavelength = (4 * jnp.arange(10) + 4)[:, None, None, None]
col_wavelength = (4 * jnp.arange(10) + 4)[None, :, None, None]
col = np.arange(15)[None, None, :, None]
row = np.arange(15)[None, None, None, :]

values = (
    jnp.sin(2 * np.pi * row / row_wavelength)
    * jnp.sin(2 * np.pi * col / col_wavelength)
)

# Try sliding the sliders:
treescope.render_array(
    values,
    columns=[2],
    sliders=[0, 1],
    axis_item_labels={
        0: [str(v) for v in row_wavelength.squeeze((1, 2, 3))],
        1: [str(v) for v in col_wavelength.squeeze((0, 2, 3))],
    },
    axis_labels={
        0: "row_wavelength",
        1: "col_wavelength",
        2: "col",
        3: "row",
    }
)

Note: Memory usage

One caveat to using render_array: whenever you render an array, the entire array is serialized, saved directly into the notebook output cell, and then loaded into your browser’s memory! That’s true even if you use sliders; although only part of your array is visible, all of the data is there in the notebook and in your local browser, so that it can update the view when you move the slider.

This can sometimes be useful, since it means the visualization does not require Colab/IPython to be connected and doesn’t affect your Python interpreter’s state. On the other hand, it’s easy to end up with very large Colab notebooks this way, and having many visualizations open can slow down your web browser. For a sense of scale, a visualization of a 1000 x 1000 array adds about 5 megabytes to the size of your notebook. (Treescope will still happily render an array of that size, though!)

Given this, it’s usually a good idea to avoid saving visualizations of very large arrays into the notebook. One way to do this in Colab is to enable the “Omit code cell output when saving this notebook” setting, which prevents the output of every cell from being saved.
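If you only need a rough picture of a very large array, another way to keep the saved output small is to render a strided subsample of it. The sketch below is one possible approach, not a Treescope feature. strided_subsample is a hypothetical helper that applies the same stride along every axis so that the result has roughly max_elements entries, and the default budget is an arbitrary choice. The figure above works out to about 5 bytes per element (5 megabytes for 10^6 elements). By that estimate, a 100,000-element subsample would add on the order of half a megabyte, though actual sizes will vary.

```python
import math
import numpy as np

def strided_subsample(array, max_elements=100_000):
  """Takes every k-th element along each axis so the result has about max_elements."""
  array = np.asarray(array)
  if array.size <= max_elements:
    return array
  # Use one stride for all axes; the size shrinks by roughly stride**ndim.
  stride = math.ceil((array.size / max_elements) ** (1.0 / array.ndim))
  return array[tuple(slice(None, None, stride) for _ in range(array.ndim))]

big = np.random.default_rng(0).normal(size=(2000, 2000))
small = strided_subsample(big)
print(small.shape)  # (286, 286)
```

You could then render treescope.render_array(small) instead of the full array. The trade-off is that striding can skip narrow features or isolated outliers, so it works best for smooth data.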

Visualizing array shardings (JAX only)

If you’re using JAX, Treescope also includes utilities for visualizing array shardings. This allows you to see how arrays are laid out across your various devices. For instance, let’s shard an array over eight devices:

from jax.experimental import mesh_utils
# Uses the 8 simulated CPU devices configured via XLA_FLAGS in the setup cell.
devices = mesh_utils.create_device_mesh((8,))
pos_sharding = jax.sharding.PositionalSharding(devices).reshape((4, 2))

sharded_array = jax.device_put(
    jnp.arange(512).reshape((16, 32)), pos_sharding
)
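jax.sharding.PositionalSharding has been deprecated in recent JAX releases, so the code above may warn or fail depending on your JAX version. A roughly equivalent 4 x 2 layout can be written with a Mesh and a NamedSharding. In this sketch, the mesh axis names "x" and "y" are arbitrary choices. Like the setup cell, the sketch sets XLA_FLAGS before importing JAX to simulate 8 CPU devices.

```python
import os
# Simulate 8 CPU devices; this must happen before JAX is first imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange the 8 devices in a 4 x 2 grid with named mesh axes.
devices = np.array(jax.devices()[:8]).reshape((4, 2))
mesh = Mesh(devices, axis_names=("x", "y"))

# Split array axis 0 over mesh axis "x" and array axis 1 over mesh axis "y".
named_sharding = NamedSharding(mesh, PartitionSpec("x", "y"))
sharded_array = jax.device_put(jnp.arange(512).reshape((16, 32)), named_sharding)

# Each device holds a 4 x 16 block (16 / 4 rows, 32 / 2 columns).
for shard in sharded_array.addressable_shards:
  print(shard.device, shard.index, shard.data.shape)
```

treescope.render_array_sharding should accept an array sharded this way as well.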

We can render the sharding of this array using treescope.render_array_sharding:

treescope.render_array_sharding(sharded_array)

This shows which device each array element is stored on.

Array shardings are also shown whenever you visualize an array with the default array autovisualizer:

%%autovisualize
treescope.display(sharded_array)

And you can look at the sharding object itself too:

%%autovisualize
treescope.display(pos_sharding)

Using the array autovisualizer

Treescope includes an automatic array visualizer, treescope.ArrayAutovisualizer, which will render arrays automatically whenever they are encountered inside a pretty-printed object.

Array autovisualization is enabled by default if you run treescope.basic_interactive_setup(). However, you can customize the way arrays are visualized by changing the autovisualizer settings, either for a particular output or globally.

For instance, to change the maximum size of automatic visualizations in a cell, you could run:

%%autovisualize treescope.ArrayAutovisualizer(maximum_size=100)

treescope.display(sharded_array)

To change it inside a scoped Python block, you can run:

with treescope.active_autovisualizer.set_scoped(
    treescope.ArrayAutovisualizer(maximum_size=100)
):
  treescope.display(sharded_array)

Or, to change it globally:

treescope.active_autovisualizer.set_globally(
    treescope.ArrayAutovisualizer(maximum_size=100)
)

See the separate “Building Custom Visualizations” tutorial for more info on how to customize automatic visualization!