Copyright 2024 The Treescope Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Visualizing Arrays with Treescope
High-dimensional NDArray (or tensor) data is common in many machine learning settings, but most plotting libraries are designed for either 2D image data or 1D time series data. Treescope includes a powerful visualizer for arrays with arbitrarily many dimensions, designed to make it easy to summarize NDArrays quickly without writing manual plotting logic.
This notebook is primarily written in terms of NumPy arrays, but the visualizer also works with other array types, including JAX arrays, PyTorch tensors, and Penzai NamedArrays!
Setup
To run this notebook, you need a Python environment with treescope and its dependencies installed. In Colab or Kaggle, you can install it using the following command:
try:
    import treescope
except ImportError:
    # IPython/Colab shell escape; only runs if treescope is missing.
    !pip install treescope
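from __future__ import annotations
from typing import Any
import os
# Simulate 8 CPU devices so the sharding examples later in this notebook work
# without accelerators. This must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
import jax.numpy as jnp
import numpy as np
import IPython
import treescope
# Enables the %%autovisualize cell magic used later in this notebook.
treescope.register_autovisualize_magic()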
Visualizing NDArrays with treescope.render_array
Visualizing numeric data and customizing colormaps
Arrays can be rendered directly with default settings by passing them to treescope.render_array:
help(treescope.render_array)
Help on function render_array in module treescope._internal.api.arrayviz:
render_array(array: 'ArrayInRegistry', *, columns: 'Sequence[AxisName | int]' = (), rows: 'Sequence[AxisName | int]' = (), sliders: 'Sequence[AxisName | int]' = (), valid_mask: 'Any | None' = None, continuous: "bool | Literal['auto']" = 'auto', around_zero: "bool | Literal['auto']" = 'auto', vmax: 'float | None' = None, vmin: 'float | None' = None, trim_outliers: 'bool' = True, dynamic_colormap: "bool | Literal['auto']" = 'auto', colormap: 'list[tuple[int, int, int]] | None' = None, truncate: 'bool' = False, maximum_size: 'int' = 10000, cutoff_size_per_axis: 'int' = 512, minimum_edge_items: 'int' = 5, axis_item_labels: 'dict[AxisName | int, list[str]] | None' = None, value_item_labels: 'dict[int, str] | None' = None, axis_labels: 'dict[AxisName | int, str] | None' = None, pixels_per_cell: 'int | float' = 7) -> 'figures_impl.TreescopeFigure'
Renders an array (positional or named) to a displayable HTML object.
Each element of the array is rendered to a fixed-size square, with its
position determined based on its index, and with each level of x and y axis
represented by a "faceted" plot.
Out-of-bounds or otherwise unusual data is rendered with an annotation:
* "X" means a value was NaN (for continuous data) or went out-of-bounds for
  the integer palette (for discrete data).
* "I" or "-I" means a value was infinity or negative infinity.
* "+" or "-" means a value was finite but went outside the bounds of the
  colormap (e.g. it was larger than ``vmax`` or smaller than ``vmin``). By
  default this applies to values more than 3 standard deviations outside the
  mean.
* Four light dots on grey means a value was masked out by ``valid_mask``, or
  truncated due to the maximum size or axis cutoff thresholds.
By default, this method automatically chooses a color rendering strategy based
on the arguments:
* If an explicit colormap is provided:
  * If ``continuous`` is True, the provided colors are interpreted as color
    stops and interpolated between.
  * If ``continuous`` is False, the provided colors are interpreted as an
    indexed color palette, and each index of the palette is used to render
    the corresponding integer, starting from zero.
* Otherwise:
  * If ``continuous`` is True:
    * If ``around_zero`` is True, uses the diverging colormap
      `default_diverging_colormap`. The initial value of this is a truncated
      version of the perceptually-uniform "Balance" colormap from cmocean,
      with blue for positive numbers and red for negative ones.
    * If ``around_zero`` is False, uses the sequential colormap
      `default_sequential_colormap`. The initial value of this is the
      perceptually-uniform "Viridis" colormap from matplotlib.
  * If ``continuous`` is False, uses a pattern-based "digitbox" rendering
    strategy to render integers up to 9,999,999 as nested squares, with one
    square per integer digit and digit colors drawn from the D3 Category20
    colormap.
Args:
  array: The array to render. The type of this array must be registered in
    the `type_registries.NDARRAY_ADAPTER_REGISTRY`.
  columns: Sequence of axis names or positional axis indices that should be
    placed on the x axis, from innermost to outermost. If not provided,
    inferred automatically.
  rows: Sequence of axis names or positional axis indices that should be
    placed on the y axis, from innermost to outermost. If not provided,
    inferred automatically.
  sliders: Sequence of axis names or positional axis indices for which we
    should show only a single slice at a time, with the index determined
    with a slider.
  valid_mask: Optionally, a boolean array with the same shape (and, if
    applicable, axis names) as `array`, which is True for the locations that
    we should actually render, and False for locations that do not have
    valid array data.
  continuous: Whether to interpret this array as numbers along the real
    line, and visualize using an interpolated colormap. If "auto", inferred
    from the dtype of `array`.
  around_zero: Whether the array data should be rendered symmetrically
    around zero using a diverging colormap, scaled based on the absolute
    magnitude of the inputs, instead of rescaled to be between the min and
    max of the data. If "auto", treated as True unless both `vmin` and
    `vmax` are set to incompatible values.
  vmax: Largest value represented in the colormap. If omitted and
    around_zero is True, inferred as ``max(abs(array))`` or as ``-vmin``. If
    omitted and around_zero is False, inferred as ``max(array)``.
  vmin: Smallest value represented in the colormap. If omitted and
    around_zero is True, inferred as ``-max(abs(array))`` or as ``-vmax``.
    If omitted and around_zero is False, inferred as ``min(array)``.
  trim_outliers: Whether to try to trim outliers when inferring ``vmin`` and
    ``vmax``. If True, clips them to 3 standard deviations away from the
    mean (or 3 sqrt-second-moments around zero) if they would otherwise
    exceed it.
  dynamic_colormap: Whether to dynamically adjust the colormap based on
    mouse hover. Requires a continuous colormap, and ``around_zero=True``.
    If "auto", will be enabled for continuous arrays if ``around_zero`` is
    True and neither ``vmin`` nor ``vmax`` are provided.
  colormap: An optional explicit colormap to use, represented as a list of
    ``(r,g,b)`` tuples, where each channel is between 0 and 255. A good
    place to get colormaps is the ``palettable`` package, e.g. you can pass
    something like ``palettable.matplotlib.Inferno_20.colors``.
  truncate: Whether or not to truncate the array to a smaller size before
    rendering.
  maximum_size: Maximum number of elements of an array to show. Arrays larger
    than this will be truncated along one or more axes. Ignored unless
    ``truncate`` is True.
  cutoff_size_per_axis: Maximum number of elements of each individual axis
    to show without truncation. Any axis longer than this will be truncated,
    with their visual size increasing logarithmically with the true axis
    size beyond this point. Ignored unless ``truncate`` is True.
  minimum_edge_items: How many values to keep along each axis for truncated
    arrays. We may keep more than this up to the budget of maximum_size.
    Ignored unless ``truncate`` is True.
  axis_item_labels: An optional mapping from axis names/positions to a list
    of strings, of the same length as the axis length, giving a label to
    each item along that axis. For instance, this could be the token string
    corresponding to each position along a sequence axis, or the class label
    corresponding to each category across a classifier's output axis. This
    is shown in the tooltip when hovering over a pixel, and shown below the
    array when a pixel is clicked on. For convenience, names in this
    dictionary that don't match any axes in the input are simply ignored, so
    that you can pass the same labels while rendering arrays that may not
    have the same axis names.
  value_item_labels: For categorical data, an optional mapping from each
    value to a string. For instance, this could be the token value
    corresponding to each token ID in a sequence of tokens.
  axis_labels: Optional mapping from axis names / indices to the labels we
    should use for that axis. If not provided, we label the named axes with
    their names and the positional axes with "axis {i}", and also add the
    axis size.
  pixels_per_cell: Size of each rendered array element in pixels, between 1
    and 21 inclusive. This controls the zoom level of the rendering. Array
    elements are always drawn at 7 pixels per cell and then rescaled, so
    out-of-bounds annotations and "digitbox" integer value patterns may not
    display correctly at fewer than 7 pixels per cell.
Returns:
  An object which can be rendered in an IPython notebook, containing the
  HTML source of an arrayviz rendering.
my_array = np.cos(np.arange(300).reshape((10,30)) * 0.2)
treescope.render_array(my_array)
Things to notice:
* The visualization is interactive! Try zooming in and out, hovering over the array to inspect individual elements, or clicking to remember a particular element.
* The shape of the array can be read off from the axis labels.
* Pixels are always square in arrayviz renderings. (In fact, they are always exactly 7 pixels by 7 pixels at zoom level 1.)
The default rendering strategy uses a diverging colormap centered at zero, with blue for positive numbers and red for negative ones, to show you the absolute magnitude and sign of the array. You can switch to a relative mode, which rescales the colormap between the minimum and maximum of the data, by passing the argument around_zero=False:
treescope.render_array(my_array, around_zero=False)
You can also customize the upper and lower bounds of the colormap by passing vmin and/or vmax:
treescope.render_array(my_array, vmax=0.7)
In this case, the array has values outside of our specified colormap bounds; those out-of-bounds values are rendered with “+” and “-” to indicate that they’ve been clipped.
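Since we didn’t pass around_zero=False, it automatically set vmin to -vmax for us. You can also set both bounds explicitly. Because the bounds below are not symmetric around zero, the default around_zero="auto" is treated as False, so the sequential colormap is used instead of the diverging one: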
treescope.render_array(my_array, vmin=-0.1, vmax=0.7)
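The rules for filling in a missing bound can be summarized in a few lines of NumPy. The helper below, resolve_bounds, is hypothetical and not part of Treescope. It is a minimal sketch of the behavior described in the render_array docstring. It assumes that “incompatible” bounds means vmin != -vmax, and it ignores outlier trimming.

```python
import numpy as np


def resolve_bounds(array, vmin=None, vmax=None, around_zero="auto"):
    """Hypothetical sketch of the vmin/vmax defaults described in the
    render_array docstring. Not Treescope's code; outlier trimming is ignored."""
    array = np.asarray(array, dtype=float)
    if around_zero == "auto":
        # Assumption: "incompatible" means both bounds are given and are not
        # negatives of each other.
        around_zero = not (vmin is not None and vmax is not None and vmin != -vmax)
    if around_zero:
        # Symmetric bounds: mirror whichever bound was given, else use max |x|.
        if vmax is None:
            vmax = -vmin if vmin is not None else float(np.max(np.abs(array)))
        if vmin is None:
            vmin = -vmax
    else:
        # Relative mode: fall back to the data's own min and max.
        if vmin is None:
            vmin = float(np.min(array))
        if vmax is None:
            vmax = float(np.max(array))
    return around_zero, vmin, vmax


data = np.array([-2.0, 0.5, 1.0])
print(resolve_bounds(data))                       # symmetric, from max |x|
print(resolve_bounds(data, vmax=0.7))             # vmin mirrored from vmax
print(resolve_bounds(data, vmin=-0.1, vmax=0.7))  # asymmetric: relative mode
```

The second call corresponds to the vmax=0.7 example above, and the third call corresponds to the vmin=-0.1, vmax=0.7 example.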
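If you want to customize the way colors are rendered, you can pass a custom colormap as a list of (R, G, B) color tuples, with each channel between 0 and 255. The palettable package is a convenient source of such colormaps: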
import palettable
treescope.render_array(my_array, colormap=palettable.matplotlib.Inferno_20.colors)
treescope.render_array(my_array, colormap=palettable.cmocean.sequential.Speed_20.colors)
Visualizing high-dimensional arrays
So far we’ve been looking at an array with two axes, but Treescope’s array renderer works out-of-the-box with arbitrarily high-dimensional arrays as well:
my_4d_array = np.cos(np.arange(5*6*7*8).reshape((5,6,7,8)) * 0.1)
treescope.render_array(my_4d_array)
For high-dimensional arrays, the individual axis labels indicate which level of the plot corresponds to which axis. Above, each 7x8 facet represents a slice my_4d_array[i,j,:,:], with individual pixels ranging along axis 2 and axis 3; this is denoted by the axis2 and axis3 labels for that facet. The six columns correspond to slices along axis 1, and the five rows correspond to slices along axis 0, as denoted by the outermost labels for those axes.
You can control which axes are assigned to which direction by passing columns (or rows), listed from innermost to outermost:
treescope.render_array(my_4d_array, columns=[2, 0, 1])
Note that the gap between the “axis0” groups is twice as large as the gap between “axis2” groups, so that you can visually distinguish the groups.
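One way to think about this faceted layout: ignoring the gaps between groups, it is a 2D mosaic. The row index combines the rows axes and the column index combines the columns axes, with each list read from innermost to outermost. The hypothetical facet_mosaic helper below builds that mosaic with NumPy. It assumes every axis appears in exactly one of rows or columns. The claim that the default layout above corresponds to rows=[2, 0] and columns=[3, 1] is our reading of the axis labels, not something taken from Treescope's source.

```python
import numpy as np


def facet_mosaic(array, rows, columns):
    """Hypothetical helper: flatten an N-D array into a gap-free 2D mosaic.

    `rows` and `columns` list axis indices from innermost to outermost, as in
    render_array; together they must cover every axis exactly once.
    """
    if sorted(list(rows) + list(columns)) != list(range(array.ndim)):
        raise ValueError("rows and columns must cover each axis exactly once")
    # Outermost axes vary slowest, so move them to the front before reshaping.
    order = list(reversed(rows)) + list(reversed(columns))
    n_rows = int(np.prod([array.shape[a] for a in rows]))
    n_cols = int(np.prod([array.shape[a] for a in columns]))
    return np.transpose(array, order).reshape(n_rows, n_cols)


my_4d_array = np.cos(np.arange(5 * 6 * 7 * 8).reshape((5, 6, 7, 8)) * 0.1)

# Our reading of the default layout: pixel rows from axis 2 within groups of
# axis 0, and pixel columns from axis 3 within groups of axis 1.
default_like = facet_mosaic(my_4d_array, rows=[2, 0], columns=[3, 1])
# The columns=[2, 0, 1] example, with the remaining axis 3 on the rows.
reordered = facet_mosaic(my_4d_array, rows=[3], columns=[2, 0, 1])
print(default_like.shape, reordered.shape)
```

In default_like, element [i, j, k, l] lands at mosaic position (i*7 + k, j*8 + l), which is the nesting the axis labels describe.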
Treescope can also visualize Penzai’s NamedArray, and it uses the array's axis names as labels. This means that, if your code is written with NamedArrays, you get labeled visualizations for free!
from penzai import pz
my_named_array = pz.nx.wrap(
    jax.random.normal(jax.random.key(1), (10, 4, 16)),
).tag("query_seq", "heads", "embed")
treescope.render_array(my_named_array, columns=["embed", "heads"])
(See the NamedArray tutorial for more information on how NamedArrays work in Penzai.)
Identifying extreme or invalid array values
By default, Treescope’s array renderer configures the colormap to show interesting detail by clipping outliers. Specifically, it limits the colormap to 3 standard deviations away from the mean (or, when around_zero is set, 3 root-mean-square magnitudes away from zero):
my_outlier_array = np.cos(np.arange(300).reshape((10,30)) * 0.2)
my_outlier_array[4, 2] = 10.0
treescope.render_array(my_outlier_array)
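The render_array docstring describes this trimming as clipping the inferred bounds to 3 standard deviations from the mean, or to 3 sqrt-second-moments (root-mean-square values) around zero. The hypothetical trimmed_bounds helper below is a minimal NumPy sketch of that rule. It ignores non-finite values, and Treescope's actual implementation may differ in its details.

```python
import numpy as np


def trimmed_bounds(array, around_zero=True, num_stds=3.0):
    """Hypothetical sketch of render_array's outlier trimming for vmin/vmax."""
    values = np.asarray(array, dtype=float)
    values = values[np.isfinite(values)]  # ignore inf/NaN when computing stats
    if around_zero:
        # Symmetric bound: max |x|, capped at num_stds root-mean-square values.
        rms = np.sqrt(np.mean(values ** 2))
        vmax = float(min(np.max(np.abs(values)), num_stds * rms))
        return -vmax, vmax
    # Relative mode: the data range, capped at num_stds standard deviations
    # from the mean on each side.
    mean, std = np.mean(values), np.std(values)
    vmin = float(max(np.min(values), mean - num_stds * std))
    vmax = float(min(np.max(values), mean + num_stds * std))
    return vmin, vmax


my_outlier_array = np.cos(np.arange(300).reshape((10, 30)) * 0.2)
my_outlier_array[4, 2] = 10.0
print(trimmed_bounds(my_outlier_array))
print(trimmed_bounds(my_outlier_array, around_zero=False))
```

For my_outlier_array, the trimmed upper bound lands well below the outlier value of 10.0, so under this rule that element would be drawn with “+”.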
Treescope’s array renderer also marks invalid values, such as infinities and NaNs, by drawing annotations on top of the visualization:
numerator = np.linspace(-5, 5, 31)
denominator = np.linspace(-1, 1, 13)
# The denominator passes through zero, deliberately producing infinities and NaNs.
array_with_infs_and_nans = numerator[None, :] / denominator[:, None]
treescope.render_array(array_with_infs_and_nans)
Above, “I” (white on a blue background) denotes positive infinity, “-I” (white on a red background) denotes negative infinity, and “X” (magenta on a black background) denotes NaN. (You can also see outlier clipping applied to a few of the largest finite values here.)
This works in relative mode too:
treescope.render_array(array_with_infs_and_nans, around_zero=False)
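The annotation rules from the docstring can be expressed as a per-element classification. The sketch below uses a hypothetical helper (annotation_codes is not a Treescope function) that returns the annotation each continuous value would receive for given colormap bounds. An empty string means the value is drawn with an ordinary color.

```python
import numpy as np


def annotation_codes(array, vmin, vmax):
    """Hypothetical helper mirroring the docstring's annotation rules
    for continuous data."""
    values = np.asarray(array, dtype=float)
    codes = np.full(values.shape, "", dtype=object)
    finite = np.isfinite(values)
    codes[np.isnan(values)] = "X"        # NaN
    codes[np.isposinf(values)] = "I"     # +infinity
    codes[np.isneginf(values)] = "-I"    # -infinity
    codes[finite & (values > vmax)] = "+"  # finite but above the colormap
    codes[finite & (values < vmin)] = "-"  # finite but below the colormap
    return codes


print(annotation_codes([np.nan, np.inf, -np.inf, 5.0, -5.0, 0.0], vmin=-1.0, vmax=1.0))
```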
If you want, you can mask out data by passing valid_mask, a boolean array with the same shape as the data. Only values where the mask is True will be rendered; masked-out values are shown in gray with a dot pattern.
valid_mask = np.isfinite(array_with_infs_and_nans) & (np.abs(array_with_infs_and_nans) < 10)
treescope.render_array(
    array_with_infs_and_nans,
    valid_mask=valid_mask,
)
Visualizing categorical data
Treescope’s array renderer also supports rendering categorical data, even with very high numbers of categories. Data with a discrete (integer or boolean) dtype is rendered as categorical by default, with different colors for different categories:
treescope.render_array(np.arange(10))
treescope.render_array(np.array([True, False, False, True, True]))
The values from 0 to 9 are rendered with solid colors, with 0 represented as white. Larger numbers are rendered using nested box patterns, with one box per digit of the number, and the color of the box indicating the value of the digit:
treescope.render_array(np.arange(1000).reshape((10,100)))
treescope.render_array(
    jnp.arange(20)[:, None] * jnp.arange(20)[None, :]
)
You can also render a single integer on its own with treescope.integer_digitbox (sometimes useful for custom visualizations). Integers with up to 7 digits are supported.
treescope.integer_digitbox(42, size="30px")
treescope.integer_digitbox(1234, size="30px")
treescope.integer_digitbox(7654321, size="30px")
Negative integers render the same way as positive ones, but with a black triangle in the corner indicating the sign:
treescope.render_array(np.arange(21 * 21).reshape((21, 21)) - 220)
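The digit decomposition behind this encoding is easy to reproduce. The hypothetical digitbox_digits helper below returns the sign flag and the per-digit values that the boxes are colored by. It does not attempt to model the nesting order or the exact colors Treescope uses. The 7-digit limit matches the docstring's stated maximum of 9,999,999.

```python
def digitbox_digits(value, max_digits=7):
    """Hypothetical helper: the sign flag and decimal digits that the digitbox
    rendering encodes (box nesting order and colors are not modeled here)."""
    value = int(value)
    digits = [int(d) for d in str(abs(value))]
    if len(digits) > max_digits:
        raise ValueError(f"{value} has more than {max_digits} digits")
    # The sign is shown separately (the black corner triangle), not as a digit.
    return value < 0, digits


for v in [7, 42, 1234, -220, 7654321]:
    print(v, digitbox_digits(v))
```

Values from 0 to 9 produce a single digit, which matches their solid-color rendering.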
If your data has a discrete dtype but you don’t want to render it as categorical, you can pass continuous=True to render it as numeric instead:
treescope.render_array(np.arange(21 * 21).reshape((21, 21)) - 220, continuous=True)
Adding axis labels
Individual axes of arrays often have some semantic meaning, e.g. the “batch” or “features” axes. If you’re using named axes, Treescope will pick those up automatically! But you can also provide them explicitly:
treescope.render_array(
    jnp.arange(20)[:, None] * jnp.arange(20)[None, :],
    axis_labels={0: "foo", 1: "bar"},
)
For some arrays, it can also be useful to associate labels with the individual indices along each axis. For instance, we might want to label a “classes” axis with each individual class, or a “sequence” axis with the tokens of the sequence.
Treescope’s array renderer allows you to pass this kind of metadata as an extra argument, and will show it to you when you hover over or click on elements of the array with your mouse.
For positional axes, you can pass any subset of the axes by position:
# Try hovering or clicking:
treescope.render_array(
    np.sin(np.linspace(0, 100, 12 * 5 * 7)).reshape((12, 5, 7)),
    axis_item_labels={
        1: ["foo", "bar", "baz", "qux", "xyz"],
        0: ["one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "eleven", "twelve"],
    }
)
For named axes, you can pass labels by name. Labels for names that don't match any axis of the array are simply ignored.
Slicing and “scrubbing” with sliders
It’s sometimes useful to look at individual slices of a large array one at a time, instead of viewing them all at once. In addition to the columns and rows arguments, render_array supports a sliders argument, which displays a slider for each of those axes and lets you “scrub” through their indices:
time = jnp.arange(100)[:, None, None]
col = np.linspace(-2, 2, 15)[None, :, None]
row = np.linspace(-2, 2, 15)[None, None, :]
values_over_time = jax.nn.sigmoid(
    0.05 * time - 2 - row - jnp.sin(2 * col - 0.1 * time)
)
# Try sliding the slider:
treescope.render_array(
    values_over_time,
    columns=[1],
    sliders=[0],
    axis_labels={
        0: "time",
        1: "col",  # `col` varies along axis 1, which is placed on the x axis
        2: "row",
    }
)
You can even put sliders for multiple axes simultaneously, if you want:
row_wavelength = (4 * jnp.arange(10) + 4)[:, None, None, None]
col_wavelength = (4 * jnp.arange(10) + 4)[None, :, None, None]
col = np.arange(15)[None, None, :, None]
row = np.arange(15)[None, None, None, :]
values = (
    jnp.sin(2 * np.pi * row / row_wavelength)
    * jnp.sin(2 * np.pi * col / col_wavelength)
)
# Try sliding the sliders:
treescope.render_array(
    values,
    columns=[2],
    sliders=[0, 1],
    axis_item_labels={
        0: [str(v) for v in row_wavelength.squeeze((1, 2, 3))],
        1: [str(v) for v in col_wavelength.squeeze((0, 2, 3))],
    },
    axis_labels={
        0: "row_wavelength",
        1: "col_wavelength",
        2: "col",  # `col` varies along axis 2, which is placed on the x axis
        3: "row",
    }
)
Note: Memory usage
One caveat to using render_array: whenever you render an array, the entire array is serialized, saved directly into the notebook output cell, and then loaded into your browser’s memory! That’s true even if you use sliders. Although only part of your array is visible, all of the data is stored in the notebook and in your browser, so that the view can update when you move the slider.
This can sometimes be useful, since it means the visualization does not require Colab/IPython to be connected and doesn’t affect your Python interpreter’s state. On the other hand, it’s easy to end up with very large Colab notebooks this way, and if you have many visualizations open, your web browser can slow down. For a sense of scale, a visualization of a 1000 x 1000 array adds about 5 megabytes to the size of your notebook. (Treescope will still happily render an array of that size, though!)
Given this, it’s usually a good idea to avoid saving visualizations of very large arrays into the notebook. One way to do this is to turn on the “Omit code cell output when saving this notebook” setting in Colab, which avoids saving the output of every cell.
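To budget notebook size, you can extrapolate from the figure quoted above. The sketch below assumes that output size scales linearly with the number of elements, at roughly 5 MB per million elements. That rate is only the rough figure from the text, and real sizes will vary. The sketch also shows that slicing before rendering (for example, keeping every 10th time step of values_over_time) shrinks the estimate proportionally.

```python
import math

# Rough rate quoted in the text: a 1000 x 1000 array adds about 5 MB.
MB_PER_MILLION_ELEMENTS = 5.0


def estimated_output_mb(shape):
    """Hypothetical estimate of notebook growth, assuming linear scaling."""
    return math.prod(shape) / 1_000_000 * MB_PER_MILLION_ELEMENTS


full = estimated_output_mb((100, 15, 15))       # values_over_time, all 100 steps
subsampled = estimated_output_mb((10, 15, 15))  # e.g. values_over_time[::10]
print(f"full: {full:.4f} MB, subsampled: {subsampled:.5f} MB")
```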
Visualizing array shardings (JAX only)
If you’re using JAX, Treescope also includes utilities for visualizing array shardings. This allows you to see how arrays are laid out across your various devices. For instance, let’s shard an array over eight devices:
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))
pos_sharding = jax.sharding.PositionalSharding(devices).reshape((4, 2))
sharded_array = jax.device_put(
    jnp.arange(512).reshape((16, 32)), pos_sharding
)
We can render the sharding of this array using treescope.render_array_sharding:
treescope.render_array_sharding(sharded_array)
This shows which device each array element is stored on.
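PositionalSharding is convenient here, but recent JAX releases have deprecated it in favor of NamedSharding over a named device Mesh. The sketch below builds an analogous 4 x 2 layout that way and checks which block each device holds, using addressable_shards. The mesh axis names "rows" and "cols" are arbitrary choices. The sketch assumes 8 devices are available, for example through the XLA_FLAGS setting from the Setup section, which only takes effect if it is set before JAX is first imported. Passing named_array to treescope.render_array_sharding should then show a similar device layout.

```python
import os

# Only effective if set before JAX is first imported in this process.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 4 x 2 grid of devices with named mesh axes.
mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), axis_names=("rows", "cols"))
# Split array axis 0 across the "rows" mesh axis and axis 1 across "cols".
named_sharding = NamedSharding(mesh, PartitionSpec("rows", "cols"))
named_array = jax.device_put(jnp.arange(512).reshape((16, 32)), named_sharding)

# Each device should hold one contiguous 4 x 16 block of the 16 x 32 array.
for shard in named_array.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```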
Array shardings are also shown whenever you visualize an array with the default array autovisualizer:
%%autovisualize
treescope.display(sharded_array)
And you can look at the sharding object itself too:
%%autovisualize
treescope.display(pos_sharding)
Using the array autovisualizer
Treescope includes an automatic array visualizer, treescope.ArrayAutovisualizer, which renders arrays automatically whenever they are encountered inside a pretty-printed object.
Array autovisualization is enabled by default if you run treescope.basic_interactive_setup(). However, you can customize how arrays are visualized by changing the autovisualizer settings, either for a particular output or globally.
For instance, to change the maximum size of automatic visualizations in a single cell, you could run:
%%autovisualize treescope.ArrayAutovisualizer(maximum_size=100)
treescope.display(sharded_array)
To change it inside a scoped Python block, you can run:
with treescope.active_autovisualizer.set_scoped(
    treescope.ArrayAutovisualizer(maximum_size=100)
):
    treescope.display(sharded_array)
Or, to change it globally:
treescope.active_autovisualizer.set_globally(
    treescope.ArrayAutovisualizer(maximum_size=100)
)
See the separate “Building Custom Visualizations” tutorial for more info on how to customize automatic visualization!