[petsc-users] Petsc/Jax no copy interfacing issues
Zhang, Hong
hongzhang at anl.gov
Tue Jul 8 14:56:08 CDT 2025
Hi Alberto,
1. To check the array pointer on the PETSc side, you can use print(hex(y_petsc.array.ctypes.data)). You will then see a pointer mismatch caused by the line y = jnp.from_dlpack(y_petsc, copy=False). This happens because your PETSc build uses double precision, whereas JAX uses single precision by default. You can either add jax.config.update("jax_enable_x64", True) to make JAX use double-precision numbers, or configure PETSc for single precision.
2. Once you fix this precision mismatch, the zero-copy conversion between PETSc and JAX should work. However, .at[].set() in JAX is not guaranteed to operate in place: array updates in JAX are generally performed out of place by design. You may want to do the updates in PETSc instead so that they don’t break the zero-copy setup.
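A minimal sketch of both points follows, assuming NumPy and a JAX version whose jnp.from_dlpack accepts objects that implement __dlpack__. So that it runs without petsc4py, the hypothetical NumPy array petsc_buf stands in for y_petsc.array (petsc4py's NumPy view of the Vec's storage), and the PETSc calls it replaces are noted in comments. The hypothetical helper aligned_zeros over-allocates to get a 64-byte-aligned buffer, an assumed value chosen to sidestep the alignment problem raised in the original message. The sketch prints the pointer comparison rather than assuming the import is zero-copy.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Point 1: match PETSc's double precision. Do this before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)


def aligned_zeros(n, alignment=64):
    """Hypothetical helper: zero-filled float64 array with an aligned data pointer."""
    raw = np.zeros(n * 8 + alignment, dtype=np.uint8)
    offset = (-raw.ctypes.data) % alignment
    return raw[offset:offset + n * 8].view(np.float64)


# Hypothetical stand-in for y_petsc.array (NumPy view of a double-precision Vec).
petsc_buf = aligned_zeros(1000)

# Import through DLPack and compare the data pointers on both sides.
y = jnp.from_dlpack(petsc_buf)
petsc_ptr = petsc_buf.ctypes.data  # with petsc4py: y_petsc.array.ctypes.data
jax_ptr = y.unsafe_buffer_pointer()
print(f"JAX dtype: {y.dtype}")
print(f"PETSc ptr {hex(petsc_ptr)}, JAX ptr {hex(jax_ptr)}, shared: {petsc_ptr == jax_ptr}")


# Point 2: .at[].set() is functional; it returns a new array.
@jax.jit
def set_all(x):
    return x.at[:].set(3.0)


y_new = set_all(y).block_until_ready()
# Without donation, jit never writes into its input, so the PETSc side is untouched.
petsc_unchanged = bool(np.all(petsc_buf == 0.0))
print(f"PETSc side unchanged after .at[].set(): {petsc_unchanged}")

# Update on the PETSc side instead (with petsc4py: y_petsc.set(3.0)),
# then take a fresh DLPack view in JAX.
petsc_buf[:] = 3.0
y_view = jnp.from_dlpack(petsc_buf)
print(f"JAX view after PETSc-side update: {float(y_view[0])}")
```

Buffer donation (donate_argnums, as in the original snippet) only allows XLA to reuse an input buffer for the result; it does not require it, so it cannot be relied on to write back into PETSc-owned memory. After a PETSc-side update, an array imported earlier may or may not reflect the change depending on whether that import was zero-copy, which is why the sketch takes a fresh view.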
Hong
From: petsc-users <petsc-users-bounces at mcs.anl.gov> on behalf of Alberto Cattaneo <bubu.cattaneo at gmail.com>
Date: Monday, July 7, 2025 at 8:40 AM
To: "petsc-users at mcs.anl.gov" <petsc-users at mcs.anl.gov>
Subject: [petsc-users] Petsc/Jax no copy interfacing issues
Greetings.
I hope this email finds you well. I’m trying to get JAX and PETSc to work together in a zero-copy setup using the DLPack tools in both libraries. Unfortunately, I can’t get it to work correctly. Ideally, I’d like to create a PETSc Vec object using petsc4py, pass it to a JAX array without copying, modify it in a JAX jitted function, and have that change reflected in the PETSc object, all without copying.
Of note: when I call the from_dlpack function, I get an error saying that the alignment is wrong and a copy must be made. Changing the alignment to 32 in PETSc’s ./configure stage makes the error message disappear, but it still doesn’t work correctly. I’ve tried looking through the documentation, but I’m getting a little turned around.
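Because the from_dlpack error is about alignment, it can help to check the address directly. The sketch below assumes only NumPy and uses two hypothetical helpers: is_aligned tests an address against a byte alignment (for a PETSc vector the address would be y_petsc.array.ctypes.data), and aligned_zeros over-allocates a NumPy buffer so its data pointer has a chosen alignment, which is handy for experiments. It does not change how PETSc itself allocates; that is set when PETSc is configured.

```python
import numpy as np


def is_aligned(ptr: int, alignment: int) -> bool:
    """Return True if the address ptr is a multiple of alignment bytes."""
    return ptr % alignment == 0


def aligned_zeros(n: int, alignment: int = 32, dtype=np.float64) -> np.ndarray:
    """Return a zero-filled 1-D array of n elements whose data pointer is
    aligned to `alignment` bytes.

    A raw byte buffer is over-allocated and sliced at the first aligned
    offset; the returned view keeps the raw buffer alive.
    """
    dtype = np.dtype(dtype)
    nbytes = n * dtype.itemsize
    raw = np.zeros(nbytes + alignment, dtype=np.uint8)
    offset = (-raw.ctypes.data) % alignment
    return raw[offset:offset + nbytes].view(dtype)


buf = aligned_zeros(1000, alignment=32)
print(hex(buf.ctypes.data), is_aligned(buf.ctypes.data, 32))

# With petsc4py, the same check on a PETSc vector would be:
# print(is_aligned(y_petsc.array.ctypes.data, 32))
```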
I’ve included a code snippet below:
from petsc4py import PETSc
import jax
from functools import partial
import jax.numpy as jnp
# donate_argnums lets XLA reuse x's buffer for the result, but does not guarantee it
@partial(jax.jit, donate_argnums=(0,))
def set_in_place(x):
    return x.at[:].set(3.0)
print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')
x = jnp.ones((1000,1))
y_petsc = PETSc.Vec().createSeq(x.shape[0])
y_petsc.set(0.0)
print(hex(y_petsc.handle))
y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))
y2_petsc.set(-1.0)
assert y_petsc.getValue(0) == y2_petsc.getValue(0)
print('After creating a second PETSc vector via a DLPack of the first, modifying the memory of one affects the other.')
#y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)
y = jnp.from_dlpack(y_petsc, copy=False)
orig_ptr = y.unsafe_buffer_pointer()
print(f'before: ptr at {hex(orig_ptr)}')
y = set_in_place(y)
print(f'after: ptr at {hex(y.unsafe_buffer_pointer())}')
assert orig_ptr == y.unsafe_buffer_pointer()
#assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'
I’d like the last two asserts to pass, but I can only get one of them to pass. If anybody is familiar with this issue, I’d greatly appreciate the help.
Respectfully:
Alberto