<html xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:w="urn:schemas-microsoft-com:office:word" xmlns:m="http://schemas.microsoft.com/office/2004/12/omml" xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8">
<meta name="Generator" content="Microsoft Word 15 (filtered medium)">
<style><!--
/* Font Definitions */
@font-face
        {font-family:"Cambria Math";
        panose-1:2 4 5 3 5 4 6 3 2 4;}
@font-face
        {font-family:DengXian;
        panose-1:2 1 6 0 3 1 1 1 1 1;}
@font-face
        {font-family:Calibri;
        panose-1:2 15 5 2 2 2 4 3 2 4;}
@font-face
        {font-family:Aptos;
        panose-1:2 11 0 4 2 2 2 2 2 4;}
@font-face
        {font-family:"\@DengXian";
        panose-1:2 1 6 0 3 1 1 1 1 1;}
/* Style Definitions */
p.MsoNormal, li.MsoNormal, div.MsoNormal
        {margin:0in;
        font-size:12.0pt;
        font-family:"Aptos",sans-serif;}
a:link, span.MsoHyperlink
        {mso-style-priority:99;
        color:blue;
        text-decoration:underline;}
span.EmailStyle19
        {mso-style-type:personal-reply;
        font-family:"Aptos",sans-serif;
        color:windowtext;}
.MsoChpDefault
        {mso-style-type:export-only;
        font-size:10.0pt;
        mso-ligatures:none;}
@page WordSection1
        {size:8.5in 11.0in;
        margin:1.0in 1.0in 1.0in 1.0in;}
div.WordSection1
        {page:WordSection1;}
--></style>
</head>
<body lang="EN-US" link="blue" vlink="purple" style="word-wrap:break-word">
<div class="WordSection1">
<p>Hi Alberto,<br>
<br>
1. To check the array pointer on the PETSc side, you can do print(hex(y_petsc.array.ctypes.data)). Then you will see a pointer mismatch caused by the line
<span style="color:#313131">y = jnp.from_dlpack(y_petsc, copy=False). This is because you configured PETSc in double precision, but JAX uses single precision by default. You can either add</span>
<span style="color:#313131">jax.config.update("jax_enable_x64", True) to make JAX use double precision number or configure PETSc to support single precision.
<br>
<br>
2. Once you fix this precision mismatch, the in-place conversion between PETSc and JAX should work. However, .at[].set() in JAX does not guarantee to operate in-place. The array updates in JAX are generally performed out-of-place by design. You may do the updates
 in PETSc so that it won’t break the zero-copy system.</span><o:p></o:p></p>
<p class="MsoNormal">Hong<o:p></o:p></p>
<p class="MsoNormal"><o:p> </o:p></p>
<div style="border:none;border-top:solid #B5C4DF 1.0pt;padding:3.0pt 0in 0in 0in">
<p class="MsoNormal"><b><span style="font-family:"Calibri",sans-serif;color:black">From:
</span></b><span style="font-family:"Calibri",sans-serif;color:black">petsc-users <petsc-users-bounces@mcs.anl.gov> on behalf of Alberto Cattaneo <bubu.cattaneo@gmail.com><br>
<b>Date: </b>Monday, July 7, 2025 at 8:40 AM<br>
<b>To: </b>"petsc-users@mcs.anl.gov" <petsc-users@mcs.anl.gov><br>
<b>Subject: </b>[petsc-users] Petsc/Jax no copy interfacing issues<o:p></o:p></span></p>
</div>
<div>
<p class="MsoNormal"><o:p> </o:p></p>
</div>
<div>
<div>
<p class="MsoNormal">Greetings.<o:p></o:p></p>
<div>
<p class="MsoNormal">I hope this email reaches you well. I’m trying to get JAX and PETSc to work together in a no-copy system using the DLPack tools in both. Unfortunately I can’t seem to get it to work right. Ideally, I’d like to create a PETSc vec object
 using petsc4py, pass it to to a JAX object without copying, make a change to it in a JAX jitted function and have that change reflected in the PETSc object. All of this without copying.<o:p></o:p></p>
</div>
<div>
<p class="MsoNormal"><o:p> </o:p></p>
</div>
<div>
<p class="MsoNormal">Of note: When I try to do this I get an error that the alignment is wrong and a copy must be made when I call the from-dlpack function but changing the alignment in the PETSc ./config stage to 32 causes the error message to disappear, even
 so it still doesn’t function correctly. I’ve tried looking through the documentation, but I’m getting a little turned around. <o:p></o:p></p>
</div>
<div>
<p class="MsoNormal"><o:p> </o:p></p>
</div>
<div>
<p class="MsoNormal">I’ve included a code snippet below:<o:p></o:p></p>
</div>
<div>
<p class="MsoNormal"><o:p> </o:p></p>
</div>
<div>
<div>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">from petsc4py import PETSc as PETSc</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">import jax</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">from functools import partial</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">import jax.numpy as jnp</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="word-spacing:1px"><i><span style="color:#313131"> </span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">@partial(jax.jit, donate_argnums=(0,))</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">def set_in_place(x):</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">    return <a href="https://urldefense.us/v3/__http:/x.at__;!!G_uCfscf7eWS!cqxG3TobpS7WZAgzxjrlWaxhAiiwWk4i9-WKReIWrc04LoXg4Y8zCkEDYGm_l5GilInGXbyzJWrD3BPRaTPlZHhIdz33$" target="_blank"><span style="color:#4285F4">x.at</span></a>[:].set(3.0)</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="word-spacing:1px"><i><span style="color:#313131"> </span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">x = jnp.ones((1000,1))</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">y_petsc = PETSc.Vec().createSeq(x.shape[0])</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">y_petsc.set(0.0)</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">print(hex(y_petsc.handle))</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">y2_petsc.set(-1.0)</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">assert y_petsc.getValue(0) == y2_petsc.getValue(0)</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">print('After creating a second PETSc vector via a DLPack of the first, modifying the memory of one affects the other.')</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">#y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">y = jnp.from_dlpack(y_petsc, copy=False)</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">orig_ptr = y.unsafe_buffer_pointer()</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">print(f'before: ptr at {hex(orig_ptr)}')</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">y = set_in_place(y)</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">print(f'after:  ptr at {hex(y.unsafe_buffer_pointer())}')</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">assert orig_ptr == y.unsafe_buffer_pointer()</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><i><span style="color:#313131">#assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'</span></i><span style="color:#313131"><o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><span style="color:#313131"><o:p> </o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><span style="color:#313131">I’d like the bottom two asserts to pass, but I can only get one of them. If somebody is familiar with this issue I’d greatly appreciate the assistance.<o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><span style="color:#313131">Respectfully:<o:p></o:p></span></p>
<p style="font-size:1rem;word-spacing:1px"><span style="color:#313131">Alberto<o:p></o:p></span></p>
</div>
</div>
</div>
</div>
</div>
</body>
</html>