{"type":"video","version":"1.0","html":"<iframe src=\"https://www.loom.com/embed/e554b233a8754d33b434d8090f524bf4\" frameborder=\"0\" width=\"2304\" height=\"1728\" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>","height":1728,"width":2304,"provider_name":"Loom","provider_url":"https://www.loom.com","thumbnail_height":1728,"thumbnail_width":2304,"thumbnail_url":"https://cdn.loom.com/sessions/thumbnails/e554b233a8754d33b434d8090f524bf4-1636178927003.gif","duration":82,"title":"What happens when you allocate a JAX tensor? (C++ stack trace walkthrough for TpuExecutor_Allocate)","description":"jax.device_put(1) is deceptively simple to write in JAX. But on a TPU, what actually happens? How does a tensor containing the value \"1\" actually get onto a TPU?\n\nTurns out, the answer is \"C++\", and a lot of it.\n\nJAX in TPU mode calls into a library called libtpu. The source code isn't publicly available, but (to my amazement) it has a simple C API (libtpu.h): https://twitter.com/cdleary/status/1336555074001141760?lang=en\n\nI'll do a more detailed writeup soon, but for now I wanted to record a quick video showing off every C++ stack frame from the time you call jax.device_put(1) all the way to the actual memory allocation on the TPU hardware.\n\nFor reference, here's the Python stack trace:\n\n# jax.device_put(1) results in these Python calls:\n~/ml/libtpu-python/jaxtest.py:27 in  # jax.device_put(1) happens here\n~/ml/jax/jax/_src/api.py:2523 in device_put\n~/ml/jax/jax/_src/tree_util.py:178 in tree_map\n~/ml/jax/jax/_src/tree_util.py:178 in \n~/ml/jax/jax/_src/api.py:2523 in \n~/ml/jax/jax/core.py:272 in bind\n~/ml/jax/jax/core.py:624 in process_primitive\n~/ml/jax/jax/interpreters/xla.py:1708 in _device_put_impl\n~/ml/jax/jax/interpreters/xla.py:301 in device_put\n~/ml/jax/jax/interpreters/xla.py:309 in _device_put_array # then jumps into C++ here\n\nNotice how little info you get. Almost *none* of the answers are in the Python codebase. 
At least, none of the answers to the hardcore engineering questions I was interested in.\n\nHere's the C++ trace. This is the actual code that drives the TPU in production:\n\n# jax.device_put(1) results in these C++ calls:\n\n*** Begin stack trace ***\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/pytree.cc#L280\nPyTreeDef::Unflatten (py::iterable) const\n  (compiler/xla/python/pytree.cc:280)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/pytree.cc#L232\npy::object\nPyTreeDef::UnflattenImpl (py::iterable) const\n  (compiler/xla/python/pytree.cc:232)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/py_client.cc#L156\nPyClient::BufferFromPyval (py::handle, PjRtDevice*, bool, PjRtClient::HostBufferSemantics)\n  (compiler/xla/python/py_client.cc:156)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/py_values.cc#L250\nDevicePut (py::handle, PjRtDevice*, const DevicePutOptions &)\n  (compiler/xla/python/py_values.cc:150)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/python/py_values.cc#L148\nHandleNumpyArray (py::handle, PjRtDevice*, const DevicePutOptions &)\n  (compiler/xla/python/py_values.cc:148)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc#L680\nPjRtStreamExecutorClient::BufferFromHostBuffer (const void *, const Shape &, PjRtClient::HostBufferSemantics, function, PjRtDevice*)\n  (compiler/xla/pjrt/pjrt_stream_executor_client.cc:680)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc#L363\nAllocateDestinationBuffer (const Shape &, PjRtDevice*, LocalDeviceState*, stream_executor::Stream*, bool, PjRtClient*, shared_ptr)\n  
(compiler/xla/pjrt/pjrt_stream_executor_client.cc:363)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/compiler/xla/service/transfer_manager.cc#L388\nTransferManager::AllocateScopedShapedBuffer (const Shape &, stream_executor::DeviceMemoryAllocator*, int, const fn)\n  (compiler/xla/service/transfer_manager.cc:390)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/stream_executor/stream_executor_pimpl.cc#L901\nstream_executor::StreamExecutorMemoryAllocator::Allocate (int, uint64_t, bool, int64_t)\n  (stream_executor/stream_executor_pimpl.cc:901)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/stream_executor/stream_executor_pimpl.cc#L487\nstream_executor::StreamExecutor::Allocate (uint64_t, int64_t)\n  (stream_executor/stream_executor_pimpl.cc:487)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/stream_executor/tpu/tpu_executor.cc#L195\ntensorflow::tpu::TpuExecutor::Allocate (uint64_t, int64_t)\n  (stream_executor/tpu/tpu_executor.cc:195)\n\nhttps://github.com/tensorflow/tensorflow/tree/v2.7.0/tensorflow/core/platform/default/stacktrace.h#L42\ntensorflow::CurrentStackTrace (bool)\n  (core/platform/default/stacktrace.h:42)\n*** End stack trace ***\n\n\nThis video was also to show off my custom build of iTerm2 (https://github.com/gnachman/iTerm2/pull/454)\nwhich drops me into CLion when I command-click filenames with line numbers. :)\n\n-- Shawn\n@theshawwn on twitter:  https://twitter.com/theshawwn\nshawwn on github: https://github.com/shawwn\nsillysaurusx on hn: https://news.ycombinator.com/threads?id=sillysaurusx"}