go-xla

module
v0.1.4
Published: Dec 17, 2025 License: Apache-2.0

README

go-xla: OpenXLA APIs bindings for Go

🎯 Why use go-xla ?

The go-xla project leverages OpenXLA to (JIT-) compile, optimize, and accelerate numeric computations (with large data) from Go using the various backends supported by OpenXLA: CPU, GPUs (Nvidia, AMD ROCm*, Intel*, Apple Metal*) and TPUs. It can be used to power Machine Learning frameworks (e.g. GoMLX), image processing, scientific computation, game AIs, etc.

And because Jax, TensorFlow and optionally PyTorch run on XLA, it is possible to run Jax functions in Go with go-xla, and probably TensorFlow and PyTorch as well.

The go-xla project aims to be minimalist and robust: it provides well-maintained, extensible Go wrappers for OpenXLA's StableHLO and OpenXLA's PJRT.

The APIs are not very "ergonomic" (error handling everywhere), but they are expected to be a stable building block for other projects to create a friendlier API on top, the same way Jax is a friendlier Python API on top of XLA/PJRT.

One such friendlier API, co-developed with go-xla, is GoMLX, a Go machine learning framework. But go-xla may also be used standalone, for lower-level access to XLA and for other accelerator use cases, such as running Jax functions in Go, or an "accelerated" image processing or scientific simulation pipeline.

🧭 What is what?

PJRT - "Pretty much Just another RunTime."

It is the heart of the OpenXLA project: it takes an IR (intermediate representation, typically StableHLO) of the "computation graph," JIT (Just-In-Time) compiles it (once) and executes it fast (many times). See Google's "PJRT: Simplifying ML Hardware and Framework Integration" blog post.
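The "compile once, execute many times" pattern can be illustrated without PJRT at all. The sketch below is purely illustrative: `compile` and `executable` are hypothetical names, and a Go closure stands in for the compiled program.

```go
package main

import "fmt"

// executable stands in for a compiled program; here it is just a Go closure.
type executable func(x float32) float32

// compileCount tracks how many times the (pretend) expensive compilation ran.
var compileCount int

// compile is a hypothetical stand-in for JIT compilation: it does the
// expensive work once and returns something cheap to call repeatedly.
func compile() executable {
	compileCount++
	return func(x float32) float32 { return x*x + 1 }
}

func main() {
	exec := compile() // Pay the compilation cost once.
	for _, x := range []float32{0, 1, 2, 3} {
		fmt.Println(exec(x)) // Execute many times with different inputs.
	}
	fmt.Println("compilations:", compileCount)
}
```

With PJRT the same structure applies: keep the compiled executable around and call it with new input buffers instead of recompiling.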

A "computation graph" is the part of your program (usually vectorial math/machine learning related) that one wants to "accelerate."

The PJRT comes in the form of a plugin, a dynamically linked library (.so file in Linux, or optionally .dylib in Darwin, or .dll in Windows). Typically, there is one plugin per hardware you are supporting. E.g.: there are PJRT plugins for CPU (Linux/amd64 and macOS for now, but likely it could be compiled for other CPUs -- SIMD/AVX are well-supported), for TPUs (Google's accelerator), GPUs (Nvidia is well-supported; there are AMD and Intel's PJRT plugins, but they were not tested), and others are in development. Some PJRT plugins are not open-source, but are available for download.
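As a small illustration of the per-platform file types listed above, here is a sketch that maps a GOOS value to the dynamic-library extension. `pluginExtension` is a hypothetical helper, not part of go-xla, and it assumes `.dylib` is the choice on Darwin.

```go
package main

import (
	"fmt"
	"runtime"
)

// pluginExtension returns the dynamic-library extension for a GOOS value:
// .so on Linux, .dylib on Darwin, .dll on Windows.
// Hypothetical helper: go-xla may resolve plugin files differently.
func pluginExtension(goos string) string {
	switch goos {
	case "darwin":
		return ".dylib"
	case "windows":
		return ".dll"
	default: // Linux and other Unix-like systems.
		return ".so"
	}
}

func main() {
	fmt.Printf("PJRT plugins on %s end in %q\n", runtime.GOOS, pluginExtension(runtime.GOOS))
}
```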

The go-xla project provides the package github.com/gomlx/go-xla/pkg/pjrt, a Go API for dynamically loading and calling the PJRT runtime. It also provides an installer library (github.com/gomlx/go-xla/pkg/installer) to auto-install (download pre-compiled binaries of) PJRT plugins for CPU (from GitHub), CUDA (from pypi.org Jax packages) and TPU (also from pypi.org).

StableHLO - "Stable High Level Optimization" (?)

The IR (intermediate representation) currently best supported by PJRT; see the specs in the StableHLO docs. It's a text representation of the computation that can easily be parsed by computers, but not easily written or read by humans.

The package github.com/gomlx/go-xla/pkg/stablehlo provides a Go API for writing StableHLO programs, including shape inference, needed to correctly infer the output shape of operations as the program is being built.
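To make "shape inference" concrete, here is a minimal sketch of the idea in plain Go. The `Shape` type and both functions are simplified, hypothetical stand-ins (not the go-xla API): element-wise operations require identical operand shapes, while a 2D matrix product yields a shape that differs from both inputs.

```go
package main

import (
	"errors"
	"fmt"
)

// Shape is a simplified, hypothetical stand-in for a tensor shape:
// an element type name plus the size of each axis (empty for scalars).
type Shape struct {
	DType      string
	Dimensions []int
}

// inferElementwise returns the output shape of a binary element-wise
// operation (e.g. add, multiply): both operands must have identical shapes.
func inferElementwise(lhs, rhs Shape) (Shape, error) {
	if lhs.DType != rhs.DType {
		return Shape{}, fmt.Errorf("dtype mismatch: %s vs %s", lhs.DType, rhs.DType)
	}
	if len(lhs.Dimensions) != len(rhs.Dimensions) {
		return Shape{}, errors.New("rank mismatch")
	}
	for i := range lhs.Dimensions {
		if lhs.Dimensions[i] != rhs.Dimensions[i] {
			return Shape{}, fmt.Errorf("axis %d differs: %d vs %d", i, lhs.Dimensions[i], rhs.Dimensions[i])
		}
	}
	return Shape{DType: lhs.DType, Dimensions: append([]int(nil), lhs.Dimensions...)}, nil
}

// inferMatMul returns the output shape of a plain 2D matrix product:
// [m, k] x [k, n] -> [m, n].
func inferMatMul(lhs, rhs Shape) (Shape, error) {
	if lhs.DType != rhs.DType {
		return Shape{}, fmt.Errorf("dtype mismatch: %s vs %s", lhs.DType, rhs.DType)
	}
	if len(lhs.Dimensions) != 2 || len(rhs.Dimensions) != 2 {
		return Shape{}, errors.New("this sketch only supports rank-2 operands")
	}
	if lhs.Dimensions[1] != rhs.Dimensions[0] {
		return Shape{}, fmt.Errorf("contracting dimensions differ: %d vs %d", lhs.Dimensions[1], rhs.Dimensions[0])
	}
	return Shape{DType: lhs.DType, Dimensions: []int{lhs.Dimensions[0], rhs.Dimensions[1]}}, nil
}

func main() {
	a := Shape{DType: "f32", Dimensions: []int{2, 3}}
	b := Shape{DType: "f32", Dimensions: []int{3, 4}}
	out, err := inferMatMul(a, b)
	fmt.Println(out.Dimensions, err)
	_, err = inferElementwise(a, b) // Shapes differ, so this reports an error.
	fmt.Println(err)
}
```

Validating shapes while the program is being built means a mistake is reported at the offending operation, rather than later when PJRT compiles the whole program.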

🗺️ How to use it?

Use stablehlo to define your computation. Then use pjrt to compile (once) and execute it.

github.com/gomlx/go-xla/pkg/stablehlo
  • Create a Builder object with stablehlo.New(name)
  • Create a Main function with Builder.Main() (or other functions with Builder.NewFunction())
  • Define the function using the various operations defined in stablehlo. Their inputs and outputs are *stablehlo.Value types, which hold a reference to the Function they are defined in, as well as their shapes.Shape.
  • Finish functions with Function.Return(values...).
  • Finish the StableHLO program with Builder.Build(): it returns the program text, which can be fed to pjrt for compilation and execution.

Here is a sample of the stablehlo Go API, to create a module that calculates $f(x) = x^2+1$ (without the error handling lines):

builder := stablehlo.New("x_times_x_plus_1") // Use valid identifier for module name
scalarF32 := shapes.Make(dtypes.F32)         // Scalar float32 shape
mainFn := builder.Main()
x, err := mainFn.NamedInput("x", scalarF32)
fX, err := stablehlo.Multiply(x, x)
one, err := mainFn.ConstantFromScalar(float32(1))
fX, err = stablehlo.Add(fX, one)
err = mainFn.Return(fX) // Set the return value for the main function
stableHLOCode, err := builder.Build()
fmt.Printf("StableHLO:\n%s\n", string(stableHLOCode))

Here is the StableHLO text generated:

module @x_times_x_plus_1 {
  func.func @main(%x: tensor<f32>) -> tensor<f32> {
    %0 = "stablehlo.multiply"(%x, %x) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %1 = "stablehlo.constant"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
    %2 = "stablehlo.add"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "stablehlo.return"(%2) : (tensor<f32>) -> ()
  }
}

github.com/gomlx/go-xla/pkg/pjrt

The pjrt package includes the following main concepts:

  • Plugin: represents a PJRT plugin. It is created by calling pjrt.GetPlugin(name) (where name is the name of the plugin). It is the main entry point to the PJRT plugin.
  • Client: the first thing created after loading a plugin. It seems one can create a singleton Client per plugin; it's not very clear to me why one would create more than one Client.
  • LoadedExecutable: created when one calls Client.Compile on a StableHLO program. The program is compiled and optimized for the PJRT target hardware and made ready to run.
  • Buffer: represents a buffer with the input/output data for the computations in the accelerators. There are methods to transfer it to/from the host memory. Buffers are the inputs and outputs of LoadedExecutable.Execute.

Example of how to run a computation with pjrt (error handling omitted):

var flagPluginName = flag.String("plugin", "cuda", "PJRT plugin name or full path")
...
err := installer.AutoInstall()  // Installs all supported plugin(s) if not already installed.
plugin, err := pjrt.GetPlugin(*flagPluginName)
client, err := plugin.NewClient(nil)
executor, err := client.Compile().WithStableHLO(stableHLOCode).Done()
values := []float32{minX, minY, maxX, maxY}
inputs := make([]*pjrt.Buffer, len(values)) // One on-device buffer per input value.
for ii, value := range values {
   inputs[ii], err = pjrt.ScalarToBuffer(client, value) // Transfer the scalar to the device.
}
outputs, err := executor.Execute(inputs...).Done()
flat, err := pjrt.BufferToArray[float32](outputs[0]) // Transfer the result back to host memory.
outputs[0].Destroy() // Don't wait for the GC, destroy the buffer immediately.
...

For a more elaborate example, see the mandelbrot.ipynb notebook, which generates the image below:

Mandelbrot fractal figure

Installation of PJRT plugin

Most programs may simply add a call to installer.AutoInstall(): it automatically downloads the PJRT plugin to the user's home directory (${HOME}/.local/lib/go-xla/ on Linux, ${HOME}/Library/Application Support/go-xla on macOS), if it is not already installed. It also auto-installs the Nvidia PJRT plugin and its required libraries if Nvidia hardware is present.

To manually install it, or if you want a specific version, consider using the command line installer with go run github.com/gomlx/go-xla/cmd/pjrt_installer@latest and follow the self-explanatory menu (or provide the flags for a quiet installation). Or run it with go run github.com/gomlx/go-xla/cmd/pjrt_installer@latest -autoinstall.

🤔 FAQ

  • When is feature X from PJRT going to be supported? The go-xla project doesn't wrap everything, although it does cover the most common operations. The simple ops and structs are auto-generated, but many require hand-writing. If a missing feature would be useful to your project, please create an issue; I'm happy to add it. I focus on the needs of GoMLX, but the idea is that go-xla can serve other purposes, and I'm happy to support them.

  • Why does PJRT spit out so many logs? Can we disable them? This is a great question ... imagine if every library we use decided to clutter our stderr as well? I have an open question in Abseil about it. It may be some issue with Abseil Logging, which also has this other issue of not allowing two different linked programs/libraries to call its initialization (see Issue #1656). A hacky workaround is to duplicate fd 2, assign the duplicate to Go's os.Stderr, and then close fd 2, so PJRT plugins have nowhere to log. This hack is implemented in the function pjrt.SuppressAbseilLoggingHack(): call it before calling pjrt.GetPlugin. But it may have unintended consequences if some other library depends on fd 2 to work, or if a real exceptional situation needs to be reported and is not.

🤝 Collaborating or asking for help

Discussion in the Slack channel #gomlx (you can join the slack server here).

Environment Variables

Environment variables that help control or debug how PJRT works:

  • PJRT_PLUGIN_LIBRARY_PATH: Path to search for PJRT plugins. The pjrt package also searches in /usr/local/lib/go-xla, ${HOME}/.local/lib/go-xla, in the standard library paths for the system, and in the paths defined in $LD_LIBRARY_PATH. For compatibility with the older version, it also searches in /usr/local/lib/gomlx/pjrt and ${HOME}/.local/lib/gomlx/pjrt. On macOS, the equivalent directories under ${HOME}/Library/Application Support/ are searched.
  • XLA_FLAGS: Used by the C++ PJRT plugins. Documentation is linked by the Jax XLA_FLAGS page, but I found it easier to just set this to "--help" and it prints out the flags.
  • XLA_DEBUG_OPTIONS: If set, it is parsed as a DebugOptions proto that is passed during the JIT-compilation (Client.Compile()) of a computation graph. It is not documented how it works in PJRT (e.g., I observed a great slow down when this is set, even if set to the default values), but the proto has some documentation.

For go-xla developers

💖 Support the Project

If you find this project helpful, please consider supporting our work through GitHub Sponsors.

Your contribution helps us (currently mostly me) maintain our coffee addiction 😃 and dedicate more time to maintaining and adding new features for the entire GoMLX ecosystem.

It also helps us acquire access (buying or cloud) to hardware for more portability, e.g. ROCm, Apple Metal (GPU), Multi-GPU/TPU, AWS Trainium, NVidia DGX Spark, Tenstorrent, etc.

💖 Acknowledgements

This project includes a (slightly modified) copy of the OpenXLA's pjrt_c_api.h file as well as some of the .proto files used by pjrt_c_api.h.

More importantly, we gratefully acknowledge the OpenXLA project and team for their valuable work in developing and maintaining these plugins.

For more information about OpenXLA, please visit their website at openxla.org, or the GitHub page at github.com/openxla/xla

⚖️ Licensing

Copyright 2025 Jan Pfeifer

The go-xla project is licensed under the Apache 2.0 license.

The OpenXLA project, including pjrt_c_api.h file, the CPU and CUDA plugins, is licensed under the Apache 2.0 license.

The CUDA plugin also uses the Nvidia CUDA Toolkit, which is subject to Nvidia's licensing terms and must be installed by the user or at the user's request.

Directories

Path Synopsis
cmd
cbuffer
Package cbuffer provides a wrapper for a C/C++ buffer that can be used to transfer data in-between pjrt, xlabuilder and the user of the library.
cmd/dtypes_codegen command
codegen parses the pjrt_c_api.h and generates boilerplate code for creating the various C structures.
cmd/exec_hlo command
exec_hlo is a trivial testing program to execute HLO programs that take as input only one value.
cmd/pjrt_codegen command
pjrt_codegen copies pjrt_c_api.h from github.com/openxla/xla source (pointed by XLA_SRC env variable), parses it and generates boilerplate code for creating the various C structures.
cmd/protoc_xla_protos command
protoc_xla_protos compiles the .proto from the OpenXLA/XLA sources to subpackages of "github.com/gomlx/go-xla/internal/protos".
cmd/xlabuilder_codegen command
codegen parses the node_types.txt and generates boilerplate code for both C and Go.
must
Package must provides a set of functions that check for errors and panic on error.
optypes
Package optypes defines OpType and lists the supported operations.
pjrt/cpudynamictest
Package cpudynamictest is just a hack around Go's limitation to use CGO in tests and to avoid cyclic dependency.
pjrt/cpustatictest
Package cpustatictest is just a hack around Go's limitation to use CGO in tests and to avoid cyclic dependency.
protos
Package protos is empty, it simply includes a rule to generate all the sub-packages: one sub-package per XLA proto used in gopjrt.
shapeinference
Package shapeinference calculates the shape resulting from operations and validates its inputs.
utils
Package utils holds small utility types and functions used internally in stablehlo.
pkg
installer
Package installer provides functionality to install PJRT plugins.
pjrt
Package pjrt implements a Go wrapper for the PJRT_C_API.
pjrt/cpu/dynamic
Package dynamic will link (preload) a dynamically loaded library `libpjrt_c_api_cpu_dynamic`, that is used if the user requests a "cpu" plugin.
pjrt/cpu/static
Package static statically links a CPU PJRT plugin, and registers with the name "cpu".
stablehlo
Package stablehlo helps build a StableHLO program (text format) to then be JIT-compiled and executed by PJRT (github.com/gomlx/go-xla/pkg/pjrt).
types/dtypes/bfloat16
Package bfloat16 is a trivial implementation for the bfloat16 type, based on https://github.com/x448/float16 and the pending issue in https://github.com/x448/float16/issues/22
types/shapes
Package shapes defines Shape and DType and associated tools.
types/shardy
Package shardy provides the types needed to define a distributed computation topology.