NDArrayAdapter

class treescope.ndarray_adapters.NDArrayAdapter

Bases: ABC, Generic[T]

An adapter to support rendering a multi-dimensional array (tensor) type.

Methods

get_array_data_with_truncation(array, mask, ...)

Returns a numpy array with truncated array (and mask) data.

get_array_summary(array, fast)

Summarizes the contents of the given array.

get_axis_info_for_array_data(array)

Returns axis information for each axis in the given array.

get_numpy_dtype(array)

Returns the numpy dtype of the given array.

get_sharding_info_for_array_data(array)

Summarizes the sharding of the given array's data.

should_autovisualize(array)

Returns True if the given array should be automatically visualized.

Inherited Methods


__init__()

abstract get_array_data_with_truncation(array: T, mask: T | None, edge_items_per_axis: tuple[int | None, ...]) → tuple[np.ndarray, np.ndarray]

Returns a numpy array with truncated array (and mask) data.

This method should construct a numpy array whose contents are a truncated version of the given array’s data; this array will be used to construct the actual array visualization. It is also responsible for broadcasting the mask appropriately and returning a compatible truncation of it.

This method may be called many times when rendering a large structure of arrays (once per array), so it should be as fast as possible. If possible, perform the truncation on the accelerator device and copy only the truncated result, to avoid unnecessary data transfer.

Parameters:
  • array – The array to get data for.

  • mask – An optional mask array provided by the user, which should be broadcast-compatible with array. (If it is not compatible, the user has provided an invalid mask, and this method should raise an informative exception.) Can be None if no mask is provided.

  • edge_items_per_axis – A tuple with one entry for each axis in the array. Each entry is either the number of items to keep on each side of this axis, or None to keep all items. The ordering will be consistent with the axis order returned by get_axis_info_for_array_data, i.e. entry k of edge_items_per_axis corresponds to entry k of the axis info tuple, regardless of the logical indices or axis names.

Returns:

A tuple (truncated_data, truncated_mask). truncated_data should be a numpy array with a truncated version of the given array’s data. If entry k of edge_items_per_axis is None, axis k should have the same size as the size field of entry k returned by get_axis_info_for_array_data. If entry k of edge_items_per_axis is not None, axis k should have a size of 2 * edge_items_per_axis[k] + 1, and the middle element (which stands in for the omitted items) can be arbitrary. truncated_mask should be a numpy array with the same shape as truncated_data containing a truncated, broadcast version of the mask; the middle element of the mask must be False along each truncated axis.
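As a rough illustration of this contract, here is a minimal NumPy-only sketch for plain numpy.ndarray inputs. The helper name truncate_with_mask is hypothetical and not part of treescope; it truncates on the host rather than on an accelerator, and it assumes that every axis with a non-None entry has at least 2 * edge + 1 items.

```python
from __future__ import annotations

import numpy as np


def truncate_with_mask(
    array: np.ndarray,
    mask: np.ndarray | None,
    edge_items_per_axis: tuple[int | None, ...],
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper following the truncation contract (host-side only)."""
    if len(edge_items_per_axis) != array.ndim:
        raise ValueError("edge_items_per_axis needs one entry per axis")
    if mask is None:
        mask = np.ones(array.shape, dtype=bool)
    try:
        # Broadcast the user's mask; an incompatible mask is a user error.
        mask = np.broadcast_to(np.asarray(mask, dtype=bool), array.shape)
    except ValueError as err:
        raise ValueError(
            f"mask of shape {np.shape(mask)} cannot be broadcast to "
            f"array shape {array.shape}"
        ) from err

    data = array
    for axis, edge in enumerate(edge_items_per_axis):
        if edge is None:
            continue  # Keep every item along this axis.
        size = data.shape[axis]
        if size < 2 * edge + 1:
            raise ValueError(f"axis {axis} is too short to truncate")
        # Keep `edge` leading items, one placeholder (item `edge`, arbitrary),
        # and `edge` trailing items: 2 * edge + 1 entries in total.
        idx = np.concatenate(
            [np.arange(edge), [edge], np.arange(size - edge, size)]
        ).astype(np.intp)
        data = np.take(data, idx, axis=axis)
        mask = np.take(mask, idx, axis=axis)  # np.take copies, so this is writable.
        middle = [slice(None)] * mask.ndim
        middle[axis] = edge
        mask[tuple(middle)] = False  # The placeholder must be masked out.
    return np.asarray(data), np.asarray(mask)
```

Gathering each truncated axis with a single index array keeps the work proportional to the output size; an adapter for an accelerator framework would presumably do the equivalent indexing on device before converting to NumPy.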

abstract get_array_summary(array: T, fast: bool) → str

Summarizes the contents of the given array.

The summary returned by this method will be used as a one-line summary of the array in treescope when automatically visualized.

If the fast argument is True, the method should return a summary that can be computed quickly, ideally without any device computation. If it is False, the method can return a more detailed summary, but it should still be fast enough to be called many times when rendering a large structure of arrays.

Parameters:
  • array – The array to summarize.

  • fast – Whether to return a fast summary that can be computed without expensive device computation.

Returns:

A summary of the given array’s contents. The summary should be a single line of text. It will be wrapped between angle brackets (< and >) when rendered.
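A minimal sketch of how the fast flag might be honored for NumPy arrays: the fast path uses only metadata (dtype and shape), while the slow path adds reductions over the data. The function name summarize_array and its output format are illustrative assumptions, not treescope's own summary format.

```python
from __future__ import annotations

import numpy as np


def summarize_array(array: np.ndarray, fast: bool) -> str:
    """Hypothetical one-line summary; `fast=True` avoids reading the data."""
    # Metadata only: available without touching the array's contents.
    header = f"{array.dtype.name}{list(array.shape)}"
    is_real = np.issubdtype(array.dtype, np.integer) or np.issubdtype(
        array.dtype, np.floating
    )
    if fast or array.size == 0 or not is_real:
        return header
    # Slow path: these reductions read every element (on an accelerator,
    # this would mean device computation plus a transfer of the results).
    finite = array[np.isfinite(array)]
    parts = [header]
    if finite.size:
        parts.append(
            f"mean={float(finite.mean()):.3g} std={float(finite.std()):.3g} "
            f"min={float(finite.min()):.3g} max={float(finite.max()):.3g}"
        )
    nonfinite = array.size - finite.size
    if nonfinite:
        parts.append(f"nonfinite={nonfinite}")
    # A single line of text; the renderer adds the surrounding < and >.
    return " ".join(parts)
```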

abstract get_axis_info_for_array_data(array: T) → tuple[AxisInfo, ...]

Returns axis information for each axis in the given array.

This method should return a tuple with an AxisInfo entry for each axis in the array. Array axes can be one of three types:

  • Positional axes have an index and a size, and can be accessed by position. This is the standard behavior of ordinary NDArrays (e.g. NumPy arrays).

  • Named positionless axes have a name and a size, and can be accessed by name only. This is how penzai.core.named_axes treats named axes.

  • Named positional axes have an index, a name, and a size, and can be accessed by either position or name. This is how PyTorch treats named axes.

Note that positional axes have an explicit “logical index”, which may or may not match their position in the underlying array data; this makes it possible to support “views” of underlying array data that have a different axis ordering than the original data. (penzai.core.named_axes uses this.)

Parameters:

array – The array to get axis information for.

Returns:

A tuple with an AxisInfo entry for each axis in the array. The ordering must be consistent with the ordering expected by get_array_data_with_truncation.
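The following sketch illustrates the three axis kinds and the data-order requirement using a hypothetical view type. NamedView and the Positional, NamedPositionless and NamedPositional dataclasses are illustrative stand-ins, not treescope's classes; a real adapter would return treescope's own AxisInfo entries instead.

```python
from __future__ import annotations

import dataclasses

import numpy as np


# Hypothetical stand-ins for the three axis kinds described above.
@dataclasses.dataclass(frozen=True)
class Positional:
    logical_index: int
    size: int


@dataclasses.dataclass(frozen=True)
class NamedPositionless:
    name: str
    size: int


@dataclasses.dataclass(frozen=True)
class NamedPositional:
    logical_index: int
    name: str
    size: int


@dataclasses.dataclass
class NamedView:
    """Hypothetical view type: `data` plus one label per axis of `data`.

    A label is an int (logical position), a str (name only), or an
    (int, str) pair (logical position and name).
    """

    data: np.ndarray
    labels: tuple


def axis_info_for(view: NamedView) -> tuple:
    # Entries follow the axis order of the underlying data, so entry k here
    # lines up with entry k of edge_items_per_axis in the truncation method.
    infos = []
    for size, label in zip(view.data.shape, view.labels):
        if isinstance(label, tuple):
            index, name = label
            infos.append(NamedPositional(index, name, size))
        elif isinstance(label, str):
            infos.append(NamedPositionless(label, size))
        else:
            infos.append(Positional(label, size))
    return tuple(infos)
```

For a view with labels (1, "heads", 0), data axis 0 has logical index 1 and data axis 2 has logical index 0 (a transposed view), yet the returned tuple still follows the order of the underlying data.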

get_numpy_dtype(array: T) → np.dtype | None

Returns the numpy dtype of the given array.

This should match the dtype of the truncated_data array returned by get_array_data_with_truncation.

Parameters:

array – The array whose dtype should be returned.

Returns:

The numpy dtype of the given array, or None if the array does not have a numpy dtype.
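One possible fallback, sketched under the assumption that the wrapped array type exposes a dtype attribute: convert it with np.dtype, and return None when there is no dtype or NumPy cannot interpret it. The helper name numpy_dtype_of is hypothetical.

```python
from __future__ import annotations

import numpy as np


def numpy_dtype_of(array: object) -> np.dtype | None:
    """Hypothetical helper: the array's NumPy dtype, or None if there is none."""
    dtype = getattr(array, "dtype", None)
    if dtype is None:
        # np.dtype(None) would silently mean float64, so check explicitly.
        return None
    try:
        return np.dtype(dtype)
    except TypeError:
        # A framework-specific dtype that NumPy does not understand.
        return None
```

Whatever this returns should equal truncated_data.dtype, so if an adapter converts its data to a different dtype while truncating, the same conversion should be reflected here.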

get_sharding_info_for_array_data(array: T) → ShardingInfo | None

Summarizes the sharding of the given array’s data.

The summary returned by this method will be used to render the array’s sharding when automatic visualization is enabled.

Parameters:

array – The array to summarize.

Returns:

A summary of the given array’s sharding, or None if it does not have a sharding.

should_autovisualize(array: T) → bool

Returns True if the given array should be automatically visualized.

If this method returns True and the array visualizer is enabled, the array will be visualized automatically.

Parameters:

array – The array to possibly visualize.

Returns:

True if the given array should be automatically visualized.
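A sketch of one possible policy for NumPy arrays. Which arrays are worth visualizing is up to each adapter; the rule below (non-scalar, non-empty, boolean or real numeric dtype) and the name should_autovisualize_numpy are illustrative assumptions, not treescope's default behavior.

```python
import numpy as np


def should_autovisualize_numpy(array: object) -> bool:
    """Hypothetical policy: visualize non-empty, non-scalar numeric arrays."""
    if not isinstance(array, np.ndarray):
        return False  # NumPy scalars (np.generic) and other types are skipped.
    if array.ndim == 0 or array.size == 0:
        return False  # Nothing useful to draw.
    # Booleans, signed/unsigned integers and floats; skip strings, objects, etc.
    return array.dtype.kind in "biuf"
```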