<div><div>Greetings.<div dir="auto">I hope this email reaches you well. I’m trying to get JAX and PETSc to work together in a no-copy system using the DLPack tools in both. Unfortunately I can’t seem to get it to work right. Ideally, I’d like to create a PETSc vec object using petsc4py, pass it to to a JAX object without copying, make a change to it in a JAX jitted function and have that change reflected in the PETSc object. All of this without copying.</div><div dir="auto"><br></div><div dir="auto">Of note: When I try to do this I get an error that the alignment is wrong and a copy must be made when I call the from-dlpack function but changing the alignment in the PETSc ./config stage to 32 causes the error message to disappear, even so it still doesn’t function correctly. I’ve tried looking through the documentation, but I’m getting a little turned around. </div><div dir="auto"><br></div><div dir="auto">I’ve included a code snippet below:</div><div dir="auto"><br></div><div dir="auto"><div><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>from petsc4py import PETSc as PETSc<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>import jax<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>from functools import partial<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>import jax.numpy as jnp<u></u><u></u></i></p><p style="font-size:16px;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i><u></u> <u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>@partial(jax.jit, donate_argnums=(0,))<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>def set_in_place(x):<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>    return <a href="https://urldefense.us/v3/__http://x.at__;!!G_uCfscf7eWS!cqxG3TobpS7WZAgzxjrlWaxhAiiwWk4i9-WKReIWrc04LoXg4Y8zCkEDYGm_l5GilInGXbyzJWrD3BPRaTPlZHhIdz33$" style="font-size:1rem;color:rgb(66,133,244)" target="_blank">x.at</a>[:].set(3.0)<u></u><u></u></i></p><p style="font-size:16px;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i><u></u> <u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>x = jnp.ones((1000,1))<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>y_petsc = PETSc.Vec().createSeq(x.shape[0])<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>y_petsc.set(0.0)<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>print(hex(y_petsc.handle))<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>y2_petsc.set(-1.0)<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>assert y_petsc.getValue(0) == y2_petsc.getValue(0)<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>print('After creating a second PETSc vector via a DLPack of the first, modifying the memory of one affects the other.')<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>#y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>y = jnp.from_dlpack(y_petsc, copy=False)<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>orig_ptr = y.unsafe_buffer_pointer()<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>print(f'before: ptr at {hex(orig_ptr)}')<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>y = set_in_place(y)<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>print(f'after:  ptr at {hex(y.unsafe_buffer_pointer())}')<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>assert orig_ptr == y.unsafe_buffer_pointer()<u></u><u></u></i></p><p style="font-size:1rem;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><i>#assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'</i></p><p style="font-size:1rem;font-style:normal;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)"><br></p><p style="font-size:1rem;font-style:normal;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)" dir="auto">I’d like the bottom two asserts to pass, but I can only get one of them. If somebody is familiar with this issue I’d greatly appreciate the assistance.</p><p style="font-size:1rem;font-style:normal;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)" dir="auto">Respectfully:</p><p style="font-size:1rem;font-style:normal;font-weight:400;letter-spacing:normal;text-indent:0px;text-transform:none;white-space:normal;word-spacing:1px;text-decoration:none;color:rgb(49,49,49)" dir="auto">Alberto</p></div></div>
</div>
</div>