Skip to content

Commit

Permalink
pyx nocopy
Browse files Browse the repository at this point in the history
  • Loading branch information
fredyshox committed Aug 31, 2024
1 parent 05fd602 commit 6879fa7
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions rednose/helpers/ekf_sym_pyx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,40 @@ cdef np.ndarray[np.float64_t, ndim=2, mode="c"] matrix_to_numpy(MatrixXdr arr):
cdef double[:,:] mem_view = <double[:arr.rows(),:arr.cols()]>arr.data()
return np.copy(np.asarray(mem_view, dtype=np.double, order="C"))

@cython.wraparound(False)
@cython.boundscheck(False)
cdef np.ndarray[np.float64_t, ndim=2, mode="c"] matrix_to_numpy_nocopy(MatrixXdr arr):
cdef double[:,:] mem_view = <double[:arr.rows(),:arr.cols()]>arr.data()
return np.asarray(mem_view, dtype=np.double, order="C")

@cython.wraparound(False)
@cython.boundscheck(False)
cdef np.ndarray[np.float64_t, ndim=1, mode="c"] vector_to_numpy(VectorXd arr):
cdef double[:] mem_view = <double[:arr.rows()]>arr.data()
return np.copy(np.asarray(mem_view, dtype=np.double, order="C"))

@cython.wraparound(False)
@cython.boundscheck(False)
cdef np.ndarray[np.float64_t, ndim=1, mode="c"] vector_to_numpy_nocopy(VectorXd arr):
cdef double[:] mem_view = <double[:arr.rows()]>arr.data()
return np.asarray(mem_view, dtype=np.double, order="C")


cdef class Result:
cdef Estimate estimate

@property
def y(self):
return vector_to_numpy_nocopy(self.estimate.y.at(0))

@property
def xk(self):
return vector_to_numpy_nocopy(self.estimate.xk)

@property
def Pk(self):
return matrix_to_numpy_nocopy(self.estimate.Pk)

cdef class EKF_sym_pyx:
cdef EKFSym* ekf
def __cinit__(self, str gen_dir, str name, np.ndarray[np.float64_t, ndim=2] Q,
Expand Down Expand Up @@ -166,18 +194,9 @@ cdef class EKF_sym_pyx:
if not res.has_value():
return None

cdef VectorXd tmpvec
return (
vector_to_numpy(res.value().xk1),
vector_to_numpy(res.value().xk),
matrix_to_numpy(res.value().Pk1),
matrix_to_numpy(res.value().Pk),
res.value().t,
res.value().kind,
[vector_to_numpy(tmpvec) for tmpvec in res.value().y],
z, # TODO: take return values?
extra_args,
)
result = Result()
result.estimate = res.value()
return result

def augment(self):
raise NotImplementedError() # TODO
Expand Down

0 comments on commit 6879fa7

Please sign in to comment.