[petsc-users] Petsc/Jax no copy interfacing issues
Zhang, Hong
hongzhang at anl.gov
Wed Jul 16 17:15:39 CDT 2025
It is expected that changes made in JAX are not reflected in the PETSc object. The reason is explained in point 2 of my previous message: array updates in JAX are performed out of place.
Hong
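The point above can be reproduced with JAX alone. This is a minimal sketch that assumes a recent JAX release and does not involve PETSc; set_all is a hypothetical function modeled on set_in_place from the original code, without the buffer donation. The jitted .at[:].set() call returns a new array and leaves its input untouched. Whether XLA reuses a donated input buffer for the result is an implementation choice that JAX does not guarantee, so a vector imported from PETSc cannot rely on seeing the write.

```python
import jax

# Enable float64 before any arrays are created, matching a
# double-precision PETSc build.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


@jax.jit
def set_all(x):
    # Functional update: returns a new array with every entry set to 3.0.
    # JAX arrays are immutable, so x itself is never modified.
    return x.at[:].set(3.0)


x = jnp.zeros(5)
y = set_all(x)

print(x.dtype)  # float64 because x64 is enabled
print(x)        # still all zeros
print(y)        # all 3.0, held in a separate array
```

If the PETSc vector is meant to stay the single copy of the data, following point 2 means performing the update on the PETSc side (for example y_petsc.set(3.0), as the original code already does with set(0.0)) and treating arrays imported into JAX as read-only inputs.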
________________________________
From: Alberto Cattaneo <bubu.cattaneo at gmail.com>
Sent: Tuesday, July 15, 2025 1:13 PM
To: Zhang, Hong <hongzhang at anl.gov>
Subject: Re: [petsc-users] Petsc/Jax no copy interfacing issues
Odd. I was using double precision (I forgot to include that in the example, sorry), but on my machine I’m still not seeing the changes made in JAX reflected in the PETSc object. Are the changes reflected on your end? Could it be an ownership issue?
On Tue, Jul 8, 2025 at 3:56 PM Zhang, Hong <hongzhang at anl.gov> wrote:
Hi Alberto,
1. To check the array pointer on the PETSc side, you can run print(hex(y_petsc.array.ctypes.data)). You will then see a pointer mismatch caused by the line y = jnp.from_dlpack(y_petsc, copy=False). This happens because you configured PETSc in double precision, while JAX uses single precision by default, so the data cannot be shared and a converted copy is made instead. You can either add jax.config.update("jax_enable_x64", True) to make JAX use double-precision numbers, or configure PETSc with single precision.
2. Once you fix this precision mismatch, the zero-copy conversion between PETSc and JAX should work. However, .at[].set() in JAX is not guaranteed to operate in place; array updates in JAX are performed out of place by design. You may instead do the updates in PETSc so that they don’t break the zero-copy setup.
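Point 1 can be illustrated without PETSc or JAX installed. This sketch is only an analogy: a float64 NumPy array stands in for y_petsc.array (the NumPy view of a double-precision PETSc vector), np.from_dlpack (NumPy 1.22 or later) stands in for a same-precision zero-copy import, and np.asarray(..., dtype=np.float32) stands in for importing into JAX at its default single precision. Comparing .ctypes.data addresses, as suggested in point 1, shows which import shares memory.

```python
import numpy as np


def buffer_ptr(arr: np.ndarray) -> int:
    """Address of the first element, the same check as hex(y_petsc.array.ctypes.data)."""
    return arr.ctypes.data


# Stand-in for y_petsc.array from a double-precision PETSc build.
petsc_like = np.zeros(1000, dtype=np.float64)

# Same dtype: the DLPack import can wrap the existing buffer.
shared = np.from_dlpack(petsc_like)

# Different dtype (single-precision consumer): float64 bytes cannot be
# reinterpreted as float32 values, so a new buffer is allocated and filled.
downcast = np.asarray(petsc_like, dtype=np.float32)

print(hex(buffer_ptr(petsc_like)), hex(buffer_ptr(shared)), hex(buffer_ptr(downcast)))

# A write through the original buffer is visible through the shared view only:
# shared[0] now reads 5.0, while downcast[0] still holds 0.0.
petsc_like[0] = 5.0
print(shared[0], downcast[0])
```

In the real setup the analogous fix is the one given in point 1: make both sides use the same precision so that the import can wrap the PETSc buffer instead of converting it.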
Hong
From: petsc-users <petsc-users-bounces at mcs.anl.gov> on behalf of Alberto Cattaneo <bubu.cattaneo at gmail.com>
Date: Monday, July 7, 2025 at 8:40 AM
To: "petsc-users at mcs.anl.gov" <petsc-users at mcs.anl.gov>
Subject: [petsc-users] Petsc/Jax no copy interfacing issues
Greetings.
I hope this email finds you well. I’m trying to get JAX and PETSc to work together in a no-copy setup using the DLPack tools in both, but I can’t seem to get it to work right. Ideally, I’d like to create a PETSc Vec object using petsc4py, pass it to JAX without copying, modify it in a jitted JAX function, and have that change reflected in the PETSc object, all without any copies.
Of note: when I call the from_dlpack function, I get an error saying that the alignment is wrong and a copy must be made. Setting the alignment to 32 at the PETSc ./configure stage makes the error message disappear, but the code still doesn’t behave correctly. I’ve tried looking through the documentation, but I’m getting a little turned around.
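The alignment error above concerns the starting address of the buffer being imported. A minimal way to inspect it from Python is to take the address (for a PETSc vector, y_petsc.array.ctypes.data as in Hong's reply) and test it against the required boundary. The exact boundary JAX/XLA requires on a given platform is not stated in this thread; ALIGNMENT = 32 is an assumption that mirrors the setting that made the error disappear, and aligned_empty is a hypothetical helper used only to produce a deterministic example with NumPy.

```python
import numpy as np


def is_aligned(addr: int, alignment: int) -> bool:
    """True if the byte address addr is a multiple of alignment."""
    return addr % alignment == 0


def aligned_empty(n: int, alignment: int, dtype=np.float64) -> np.ndarray:
    """Hypothetical helper: n uninitialized elements starting on an
    alignment-byte boundary, obtained by over-allocating bytes and slicing."""
    itemsize = np.dtype(dtype).itemsize
    raw = np.empty(n * itemsize + alignment, dtype=np.uint8)
    offset = (-raw.ctypes.data) % alignment  # bytes to skip to reach the boundary
    return raw[offset:offset + n * itemsize].view(dtype)


ALIGNMENT = 32  # assumption: mirrors the PETSc setting reported in this thread

buf = aligned_empty(1000, ALIGNMENT)
print(hex(buf.ctypes.data), is_aligned(buf.ctypes.data, ALIGNMENT))

# Shifting the start by one float64 (8 bytes) breaks 32-byte alignment.
shifted = buf[1:]
print(hex(shifted.ctypes.data), is_aligned(shifted.ctypes.data, ALIGNMENT))
```

Passing this check only removes the alignment obstacle; as Hong's reply explains, a precision mismatch can still force a copy.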
I’ve included a code snippet below:
from petsc4py import PETSc
import jax
from functools import partial
import jax.numpy as jnp
@partial(jax.jit, donate_argnums=(0,))
def set_in_place(x):
    return x.at[:].set(3.0)
print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')
x = jnp.ones((1000,1))
y_petsc = PETSc.Vec().createSeq(x.shape[0])
y_petsc.set(0.0)
print(hex(y_petsc.handle))
y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))
y2_petsc.set(-1.0)
assert y_petsc.getValue(0) == y2_petsc.getValue(0)
print('After creating a second PETSc vector via a DLPack of the first, modifying the memory of one affects the other.')
#y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)
y = jnp.from_dlpack(y_petsc, copy=False)
orig_ptr = y.unsafe_buffer_pointer()
print(f'before: ptr at {hex(orig_ptr)}')
y = set_in_place(y)
print(f'after: ptr at {hex(y.unsafe_buffer_pointer())}')
assert orig_ptr == y.unsafe_buffer_pointer()
#assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'
I’d like the bottom two asserts to pass, but I can only get one of them to pass. If somebody is familiar with this issue, I’d greatly appreciate the assistance.
Respectfully:
Alberto