<?xml version="1.0" encoding="UTF-8"?><oembed><type>video</type><version>1.0</version><html>&lt;iframe src=&quot;https://www.loom.com/embed/e554b233a8754d33b434d8090f524bf4&quot; frameborder=&quot;0&quot; width=&quot;2304&quot; height=&quot;1728&quot; webkitallowfullscreen mozallowfullscreen allowfullscreen&gt;&lt;/iframe&gt;</html><height>1728</height><width>2304</width><provider_name>Loom</provider_name><provider_url>https://www.loom.com</provider_url><thumbnail_height>1728</thumbnail_height><thumbnail_width>2304</thumbnail_width><thumbnail_url>https://cdn.loom.com/sessions/thumbnails/e554b233a8754d33b434d8090f524bf4-1636178927003.gif</thumbnail_url><duration>82</duration><title>What happens when you allocate a JAX tensor? (C++ stack trace walkthrough for TpuExecutor_Allocate)</title><description>jax.device_put(1) is deceptively simple to write in JAX. But on a TPU, what actually happens? How does a tensor containing the value &quot;1&quot; actually get onto a TPU?

Turns out, the answer is &quot;C++&quot;, and a lot of it.

JAX in TPU mode calls into a library called libtpu. Its source code isn&apos;t publicly available, but (to my amazement) the library has a simple C API (libtpu.h): https://twitter.com/cdleary/status/1336555074001141760?lang=en
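
Because libtpu exposes a plain C API, one quick way to poke at it is to load the shared library and ask whether a given symbol is exported. The sketch below is only an illustration under assumptions: `exports_symbol` is a hypothetical helper, the `libtpu.so` path is a placeholder, and whether `TpuExecutor_Allocate` (the function named in the title) is visible this way depends on the libtpu build installed on the machine.

```python
import ctypes


def exports_symbol(library_path, symbol_name):
    """Return True or False if the library loads, or None if it cannot be loaded."""
    try:
        library = ctypes.CDLL(library_path)
    except OSError:
        # Library missing or not loadable on this machine.
        return None
    # ctypes looks up attributes as exported symbols, so hasattr works as a symbol check.
    return hasattr(library, symbol_name)


if __name__ == "__main__":
    # "libtpu.so" is a placeholder path; point it at the libtpu on your TPU VM.
    print(exports_symbol("libtpu.so", "TpuExecutor_Allocate"))
```

On a machine without libtpu this prints None, since the library cannot be loaded.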

I&apos;ll do a more detailed writeup soon, but for now I wanted to record a quick video showing off every C++ stack frame from the moment you call jax.device_put(1) all the way to the actual memory allocation on the TPU hardware.

For reference, here&apos;s the Python stack trace:

# jax.device_put(1) results in these Python calls:
~/ml/libtpu-python/jaxtest.py:27 in  # jax.device_put(1) happens here
~/ml/jax/jax/_src/api.py:2523 in device_put
~/ml/jax/jax/_src/tree_util.py:178 in tree_map
~/ml/jax/jax/_src/tree_util.py:178 in 
~/ml/jax/jax/_src/api.py:2523 in 
~/ml/jax/jax/core.py:272 in bind
~/ml/jax/jax/core.py:624 in process_primitive
~/ml/jax/jax/interpreters/xla.py:1708 in _device_put_impl
~/ml/jax/jax/interpreters/xla.py:301 in device_put
~/ml/jax/jax/interpreters/xla.py:309 in _device_put_array # then jumps into C++ here
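
A trace in the `path:line in function` format shown above can be produced with the standard `traceback` module. This is a minimal sketch, not necessarily how the trace above was captured; `format_stack`, `inner`, and `outer` are hypothetical names used only for illustration.

```python
import os
import traceback


def format_stack(frames):
    """Render frames as `path:line in function`, outermost call first."""
    home = os.path.expanduser("~")
    lines = []
    for frame in frames:
        path = frame.filename
        # Abbreviate the home directory to "~", as in the trace above.
        if path.startswith(home + os.sep):
            path = "~" + path[len(home):]
        lines.append(f"{path}:{frame.lineno} in {frame.name}")
    return lines


def inner():
    # extract_stack() returns the frames leading to this call, oldest first.
    return format_stack(traceback.extract_stack())


def outer():
    return inner()


if __name__ == "__main__":
    print("\n".join(outer()))
```

The last two lines of the output are the `outer` and `inner` frames, in that order. Like the trace above, this only sees Python frames and stops where the interpreter hands off to native code.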

Notice how little info you get. Almost *none* of the answers are in the Python codebase. At least, none of the answers to the hardcore engineering questions I was interested in.

Here&apos;s the C++ trace. This is the actual code that drives the TPU in production:

# jax.device_put(1) results in these C++ calls:

*** Begin stack trace ***
https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/pytree.cc#L280
PyTreeDef::Unflatten (py::iterable) const
  (compiler/xla/python/pytree.cc:280)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/pytree.cc#L232
py::object
PyTreeDef::UnflattenImpl (py::iterable) const
  (compiler/xla/python/pytree.cc:232)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/py_client.cc#L156
PyClient::BufferFromPyval (py::handle, PjRtDevice*, bool, PjRtClient::HostBufferSemantics)
  (compiler/xla/python/py_client.cc:156)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/py_values.cc#L250
DevicePut (py::handle, PjRtDevice*, const DevicePutOptions &amp;)
  (compiler/xla/python/py_values.cc:250)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/py_values.cc#L148
HandleNumpyArray (py::handle, PjRtDevice*, const DevicePutOptions &amp;)
  (compiler/xla/python/py_values.cc:148)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc#L680
PjRtStreamExecutorClient::BufferFromHostBuffer (const void *, const Shape &amp;, PjRtClient::HostBufferSemantics, function, PjRtDevice*)
  (compiler/xla/pjrt/pjrt_stream_executor_client.cc:680)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc#L363
AllocateDestinationBuffer (const Shape &amp;, PjRtDevice*, LocalDeviceState*, stream_executor::Stream*, bool, PjRtClient*, shared_ptr)
  (compiler/xla/pjrt/pjrt_stream_executor_client.cc:363)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/service/transfer_manager.cc#L388
TransferManager::AllocateScopedShapedBuffer (const Shape &amp;, stream_executor::DeviceMemoryAllocator*, int, const fn)
  (compiler/xla/service/transfer_manager.cc:390)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/stream_executor/stream_executor_pimpl.cc#L901
stream_executor::StreamExecutorMemoryAllocator::Allocate (int, uint64_t, bool, int64_t)
  (stream_executor/stream_executor_pimpl.cc:901)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/stream_executor/stream_executor_pimpl.cc#L487
stream_executor::StreamExecutor::Allocate (uint64_t, int64_t)
  (stream_executor/stream_executor_pimpl.cc:487)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/stream_executor/tpu/tpu_executor.cc#L195
tensorflow::tpu::TpuExecutor::Allocate (uint64_t, int64_t)
  (stream_executor/tpu/tpu_executor.cc:195)

https://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/core/platform/default/stacktrace.h#L42
tensorflow::CurrentStackTrace (bool)
  (core/platform/default/stacktrace.h:42)
*** End stack trace ***


I also made this video to show off my custom build of iTerm2 (https://github.com/gnachman/iTerm2/pull/454),
which drops me into CLion when I command-click filenames with line numbers. :)

-- Shawn
@theshawwn on twitter:  https://twitter.com/theshawwn
shawwn on github: https://github.com/shawwn
sillysaurusx on hn: https://news.ycombinator.com/threads?id=sillysaurusx</description></oembed>