Copyright 2024 The Treescope Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Building Custom Visualizations
Treescope allows you to customize the renderings it generates to support more advanced visualization workflows. This customization can be done in a few different ways:
- You can use the treescope.figures subpackage to build your own top-level visualizations, by styling text and interleaving figures or array visualizations.
- You can define your own treescope.Autovisualizer and use it to automatically add rich visualizations to internal parts of rendered objects.
- You can implement the __treescope_repr__ method on your custom types to add support for rendering them with Treescope.
Setup
To run this notebook, you need a Python environment with treescope and its dependencies installed.
In Colab or Kaggle, you can install it using the following command:
try:
  import treescope
except ImportError:
  !pip install treescope
from __future__ import annotations
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
import IPython
import treescope
treescope.basic_interactive_setup()
import plotly.io
import plotly.express as px
# Treescope uses embed settings similar to Colab's, so configure plotly to
# render like it does in Colab:
plotly.io.renderers.default = "colab"
Building simple figures with treescope.figures
The treescope.figures submodule contains helper functions for arranging and styling information, using Treescope’s internal rendering system. This can be used to produce custom outputs that arrange Treescope components in different ways.
Here’s a demo of some of the building blocks:
treescope.figures.inline(
    "This is a simple inline output figure. You can ",
    treescope.figures.bolded("emphasize"),
    " parts of the output, or ",
    treescope.figures.with_color("change their color, ", "red"),
    "or even\n"
    "embed Python objects like ",
    [1, 2, 3],
    " or array visualizations like\n",
    np.linspace(-10, 10, 20),
    ".\nIt's also possible to indent parts of the output, like",
    treescope.figures.indented(treescope.figures.inline(
        "this. Indents apply to\nnewlines and\nembedded objects too:\n",
        [1, 2, 3, 4, 5]
    )),
    "You can also embed figures from other libraries:\n",
    px.histogram(
        jax.random.uniform(jax.random.key(0), (1000,)),
        width=400, height=200
    ).update_layout(
        margin=dict(l=20, r=20, t=20, b=20)
    ),
    "\nAdditionally, you can insert colored \"digitboxes\", which Treescope uses\n",
    "to render token IDs: ",
    treescope.integer_digitbox(1, label="1"),
    " ",
    treescope.integer_digitbox(2, label="2"),
    " ",
    treescope.integer_digitbox(12345, label="12345"),
    ". And you can add ",
    treescope.figures.text_on_color(
        "text with a ",
        value=0.2, vmax=1.0
    ),
    treescope.figures.text_on_color(
        "colormapped",
        value=-1.0, vmax=1.0
    ),
    treescope.figures.text_on_color(
        " background color,",
        value=0.8, vmax=1.0
    ),
    "\nwhich can be useful for showing token probabilities or similar per-token info."
)
See the documentation for treescope.figures for more info.
Defining a custom automatic subtree visualizer
As discussed in the other tutorials, Treescope supports automatically visualizing arrays inside rendered objects. For instance:
np.arange(10)
[np.arange(10), np.linspace(-10, 10, 20)]
Automatic array visualization is a special case of a more general Treescope feature, which lets you render arbitrary figures at arbitrary points in pretty-printed PyTrees. To customize automatic visualization, you define an autovisualizer function with the following signature:
def autovisualizer_fn(
    value: Any,
    path: tuple[Any, ...] | None,
) -> treescope.IPythonVisualization | treescope.ChildAutovisualizer | None:
  ...
This function will be called on every subtree of the rendered tree, and can return treescope.IPythonVisualization(some_figure) to replace the subtree with a visualization, or None to process the subtree normally. (It can also return treescope.ChildAutovisualizer if the subtree should be rendered with a different autovisualizer.)
For instance, we can write an autovisualizer that always formats arrays in continuous mode:
def my_continuous_autovisualizer(
    value: Any,
    path: tuple[Any, ...] | None,
):
  if isinstance(value, np.ndarray):
    return treescope.IPythonVisualization(
        treescope.render_array(value, continuous=True, around_zero=False),
        replace=True,
    )
  # Anything else is rendered normally.
  return None
with treescope.active_autovisualizer.set_scoped(
    my_continuous_autovisualizer
):
  IPython.display.display({
      "foo": np.arange(10)[:, None] * np.arange(10)[None, :],
      "bar": np.sin(np.arange(100) * 0.1).reshape((10, 10))
  })
Or you can add additional metadata:
def my_verbose_autovisualizer(
    value: Any,
    path: tuple[Any, ...] | None,
):
  if isinstance(value, np.ndarray):
    size = value.size
    # Split the id into five base-1000 groups, most significant first.
    token_groups = [
        (id(value) // div) % 1000
        for div in (1000000000000, 1000000000, 1000000, 1000, 1)
    ]
    return treescope.IPythonVisualization(
        treescope.figures.inline(
            "Hello world!\n",
            treescope.render_array(value),
            f"\nThis array contains {size} elements and has Python id {id(value):,}, which you could tokenize as ",
            treescope.integer_digitbox(token_groups[0], label=str(token_groups[0])),
            " ", treescope.integer_digitbox(token_groups[1], label=str(token_groups[1])),
            " ", treescope.integer_digitbox(token_groups[2], label=str(token_groups[2])),
            " ", treescope.integer_digitbox(token_groups[3], label=str(token_groups[3])),
            " ", treescope.integer_digitbox(token_groups[4], label=str(token_groups[4])),
            f"\nThe path to this node is {path}",
        ),
        replace=False
    )
with treescope.active_autovisualizer.set_scoped(
    my_verbose_autovisualizer
):
  IPython.display.display({
      "foo": np.arange(10)[:, None] * np.arange(10)[None, :],
      "bar": np.sin(np.arange(100) * 0.1).reshape((10, 10))
  })
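The token_groups computation above extracts base-1000 groups of the id using a fixed list of divisors. As a standalone sketch of the same arithmetic, the hypothetical helper below repeatedly applies divmod by 1000 instead, so it produces however many groups the number needs, in the same order as the comma-separated digits printed by f"{n:,}".

```python
def thousands_groups(n: int) -> list[int]:
  """Splits a non-negative integer into base-1000 groups, most significant first."""
  if n < 0:
    raise ValueError("n must be non-negative")
  groups = []
  while True:
    # Peel off the lowest three decimal digits on each iteration.
    n, group = divmod(n, 1000)
    groups.append(group)
    if n == 0:
      break
  return groups[::-1]


example = 140_234_567_890_123  # an id-sized integer, chosen for illustration
print(f"{example:,}")  # prints 140,234,567,890,123
print(thousands_groups(example))  # prints [140, 234, 567, 890, 123]
```

Each group could then be passed to treescope.integer_digitbox(group, label=str(group)), as the autovisualizer above does.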
You can even render values using an external plotting library like plotly! Treescope can inline any type of figure that has a rich HTML representation (specifically, any object that defines the magic _repr_html_ method expected by Colab’s IPython kernel).
def my_plotly_autovisualizer(
    value: Any,
    path: tuple[Any, ...] | None,
):
  if isinstance(value, (np.ndarray, jax.Array)):
    return treescope.IPythonVisualization(
        px.histogram(
            value.flatten(),
            width=400, height=200
        ).update_layout(
            margin=dict(l=20, r=20, t=20, b=20)
        )
    )
with treescope.active_autovisualizer.set_scoped(
    my_plotly_autovisualizer
):
  IPython.display.display({
      "foo": np.arange(10)[:, None] * np.arange(10)[None, :],
      "bar": np.sin(np.arange(100) * 0.1).reshape((10, 10))
  })
You can also pass custom visualizers to the %%autovisualize magic to let it handle the set_scoped boilerplate for you:
%%autovisualize my_plotly_autovisualizer
{
    "foo": np.arange(10)[:, None] * np.arange(10)[None, :],
    "bar": np.sin(np.arange(100) * 0.1).reshape((10, 10))
}
Adding support for rendering custom types to Treescope
You can customize how Treescope renders your type by implementing the __treescope_repr__ method, with the signature:
class MyCustomType:
  ...
  def __treescope_repr__(
      self,
      path: str,
      subtree_renderer: Callable[
          [Any, str | None], treescope.rendering_parts.Rendering
      ],
  ) -> treescope.rendering_parts.Rendering | type(NotImplemented):
    ...
Here path is a string path to this node from the root node, and subtree_renderer is a function that maps a child node and its path to a rendering for that child node.
The type treescope.rendering_parts.Rendering is Treescope’s internal representation of a rendered object, which can be converted to either text or HTML. The simplest way to build a rendering is to use one of the high-level helpers in treescope.repr_lib. For instance:
class MySimpleType:
  def __init__(self, foo):
    self.foo = foo
  def __treescope_repr__(self, path, subtree_renderer):
    return treescope.repr_lib.render_object_constructor(
        object_type=type(self),
        attributes={"foo": self.foo},
        path=path,
        subtree_renderer=subtree_renderer,
        # Pass `roundtrippable=True` only if you can rebuild your object by
        # calling `__init__` with these attributes!
        roundtrippable=True,
    )
MySimpleType(123)
For more advanced customization, you can also directly build a rendering using the low-level definitions in treescope.rendering_parts.
If your type is an array or a tensor (like np.ndarray or jax.Array), you can also add support for automatic visualization by implementing an NDArrayAdapter for it. See the documentation for treescope.ndarray_adapters.NDArrayAdapter for details.
You can then implement the special method __treescope_ndarray_adapter__, with the signature
class MyCustomType:
  ...
  def __treescope_ndarray_adapter__(self) -> treescope.ndarray_adapters.NDArrayAdapter:
    ...
which should return an adapter for your type.
(Alternatively, custom types can also be registered using the global registries treescope.type_registries.TREESCOPE_HANDLER_REGISTRY and treescope.type_registries.NDARRAY_ADAPTER_REGISTRY.)