Coder Social home page Coder Social logo

flaport / klujax Goto Github PK

View Code? Open in Web Editor NEW
31.0 1.0 2.0 97 KB

Solve sparse linear systems in JAX using the KLU algorithm

License: GNU Lesser General Public License v2.1

Python 77.27% C++ 20.38% Makefile 2.35%
autograd jax klu klu-algorithm python solve solver sparse-linear-solver sparse-linear-systems sparse-matrices

klujax's People

Contributors

flaport avatar jan-david-fischbach avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar

klujax's Issues

Segmentation Fault KLUSOLVE

Hey there,

I have been trying to use KLUJAX in a topology optimization algorithm, where I need to find the jax.grad of a function called obj_fun which solves a large sparse matrix system using KLUJAX. I am unfortunately receiving a segmentation fault. I have tried increasing the allocated memory and I have attempted to track the fault and found it fails in a JAX function here:

python3.9/site-packages/jax/_src/dispatch.py:895.

Any ideas how to approach this? Thanks!

Best,
Adam

import numpy as np
import jax.numpy as jnp
from jax import value_and_grad
from jax.experimental import sparse as jsparse
from klujax import solve as klusolve
import sys
sys.settrace

def trace(frame, event, arg):
    print("%s, %s:%d" % (event, frame.f_code.co_filename, frame.f_lineno))
    return trace

def main():

  nelx=60
  nely=20
  volfrac=0.3
  rmin=1.5
  penal=6.0

  Emin=1e-9
  Emax=1.0

  ndof = 2*(nelx+1)*(nely+1)

  H, Hs = h_build(nelx, nely, rmin)
  ijK,edofMat = buildstiff_index(nelx,nely)

  x=volfrac * jnp.ones(nely*nelx,dtype=float)
  xold = jnp.copy(x)
  xPhys = jnp.copy(x)
  g=0 # must be initialized to use the NGuyen/Paulino OC approach
  dofs=np.arange(2*(nelx+1)*(nely+1))
  fixed=np.union1d(dofs[0:2*(nely+1):2],np.array([2*(nelx+1)*(nely+1)-1]))
  free=np.setdiff1d(dofs,fixed)
  f=np.zeros((ndof,1))
  f[1,0]=-1
  loop=0
  change=1
  while change>0.01 and loop<50:
    loop=loop+1
    print(loop)

    (obj, dc) = value_and_grad(obj_fun, argnums=0)(xPhys,x,nelx,nely,free,penal,Emax,Emin,fixed,f,H,Hs,edofMat,ijK)

    dv = np.ones(nely*nelx)
    dv[:] = np.asarray(H*(dv[np.newaxis].T/Hs))[:,0]

    xold = jnp.copy(x)
    (x_holder,g)=oc(nelx,nely,x,volfrac,dc,dv,g)
    x = x.at[:].set(x_holder)

    change=np.linalg.norm(x.reshape(nelx*nely,1)-xold.reshape(nelx*nely,1),np.inf)
    # Write iteration history to screen (req. Python 2.6 or newer)
    print("it.: {0} , obj.: {1:.3f} Vol.: {2:.3f}, ch.: {3:.3f}".format(\
          loop,obj,(g+volfrac*nelx*nely)/(nelx*nely),change))

def obj_fun(xPhys,x,nelx,nely,free,penal,Emax,Emin,fixed,f,H,Hs,edofMat,ijK): 
  ndof = 2*(nelx+1)*(nely+1)
  u=np.zeros((ndof,1))
  ce=np.ones(nelx*nely)
  xPhys=np.asarray(jsparse.bcoo_multiply_dense(H,x[np.newaxis].T/Hs))[:]
  K,KE = buildstiff(xPhys,penal,Emax,Emin,ndof,ijK)
  B = f[free, 0]
  u[free,0] = klusolve(K.indices[:,0],K.indices[:,1],K.data,B)
  ce[:] = (np.dot(u[edofMat].reshape(nelx*nely,8),KE) * u[edofMat].reshape(nelx*nely,8) ).sum(1)
  obj=( (Emin+xPhys**penal*(Emax-Emin))*ce )
  return obj.sum()

def buildstiff_index(nelx,nely):
  edofMat=np.zeros((nelx*nely,8),dtype=int)
  for elx in range(nelx):
    for ely in range(nely):
      el = ely+elx*nely
      n1=(nely+1)*elx+ely
      n2=(nely+1)*(elx+1)+ely
      edofMat[el,:]=np.array([2*n1+2, 2*n1+3, 2*n2+2, 2*n2+3,2*n2, 2*n2+1, 2*n1, 2*n1+1])
  ijK = np.vstack((np.kron(edofMat,np.ones((8,1))).flatten().astype(int),np.kron(edofMat,np.ones((1,8))).flatten().astype(int))).T
  return ijK,edofMat

def buildstiff(xPhys,penal,Emax,Emin,ndof,ijK):
  KE=lk()
  sK=((KE.flatten()[np.newaxis]).T*(Emin+(xPhys)**penal*(Emax-Emin))).flatten(order='F')
  K = jsparse.BCOO((sK,ijK),shape=(ndof,ndof))
  return K,KE

def h_build(nelx,nely,rmin):
  nfilter=int(nelx*nely*((2*(np.ceil(rmin)-1)+1)**2))
  ijH = np.zeros((nfilter,2)).astype(int)
  sH = np.zeros(nfilter).astype(int)
  cc=0
  for i in range(nelx):
    for j in range(nely):
      row=i*nely+j
      kk1=int(np.maximum(i-(np.ceil(rmin)-1),0))
      kk2=int(np.minimum(i+np.ceil(rmin),nelx))
      ll1=int(np.maximum(j-(np.ceil(rmin)-1),0))
      ll2=int(np.minimum(j+np.ceil(rmin),nely))
      for k in range(kk1,kk2):
        for l in range(ll1,ll2):
          col=k*nely+l
          fac=rmin-np.sqrt(((i-k)*(i-k)+(j-l)*(j-l)))
          ijH[cc, 0]=row
          ijH[cc, 1]=col
          sH[cc]=np.maximum(0.0,fac)
          cc=cc+1
  H=jsparse.BCOO((sH, ijH), shape=(nelx*nely,nelx*nely))
  Hs=sH.sum().astype(float)
  return H,Hs

def oc(nelx,nely,x,volfrac,dc,dv,g):
	l1=0
	l2=1e9
	move=0.2
	xnew=np.zeros(nelx*nely)
	while (l2-l1)/(l1+l2)>1e-3:
		lmid=0.5*(l2+l1)
		xnew[:]= np.maximum(0.0,np.maximum(x-move,np.minimum(1.0,np.minimum(x+move,x*np.sqrt(-dc/dv/lmid)))))
		gt=g+np.sum((dv*(xnew-x)))
		if gt>0 :
			l1=lmid
		else:
			l2=lmid
	return (xnew,gt)

def deleterowcol(A, delrow, delcol):
	# Assumes that matrix is in symmetric csc form !
	m = A.shape[0]
	keep = np.delete (np.arange(0, m), delrow)
	A = A.at[:].set(A[keep, :])  
	keep = np.delete (np.arange(0, m), delcol)
	A = A.at[:].set(A[keep, :])
	return A   

def lk():
	E=1
	nu=0.3
	k=np.array([1/2-nu/6,1/8+nu/8,-1/4-nu/12,-1/8+3*nu/8,-1/4+nu/12,-1/8-nu/8,nu/6,1/8-3*nu/8])
	KE = E/(1-nu**2)*np.array([ [k[0], k[1], k[2], k[3], k[4], k[5], k[6], k[7]],
	[k[1], k[0], k[7], k[6], k[5], k[4], k[3], k[2]],
	[k[2], k[7], k[0], k[5], k[6], k[3], k[4], k[1]],
	[k[3], k[6], k[5], k[0], k[7], k[2], k[1], k[4]],
	[k[4], k[5], k[6], k[7], k[0], k[1], k[2], k[3]],
	[k[5], k[4], k[3], k[2], k[1], k[0], k[7], k[6]],
	[k[6], k[3], k[4], k[1], k[2], k[7], k[0], k[5]],
	[k[7], k[2], k[1], k[4], k[3], k[6], k[5], k[0]] ]);
	return (KE)

sys.settrace(trace)
main()

Memory Leak

Hi,
I use klujax a lot (and I really like it!), but it seems like there are a few memory leaks.

Consider the following python script:

from math import prod

import jax
import klujax
from jax import numpy as jnp


def random_sparse(key, shape, fill_ratio):
    size = prod(shape)
    ndim = len(shape)
    n = int(fill_ratio * size)

    (key_data, key_indices) = jax.random.split(key, 2)

    data = jax.random.normal(key_data, shape=(n,))
    indices = jnp.floor(jnp.asarray(shape) * jax.random.uniform(key_indices, shape=(n, ndim))).astype(jnp.integer)

    return (data, indices, shape)


n = 50_000

key = jax.random.PRNGKey(0)
(key_A, key_b) = jax.random.split(key, 2)

(data, indices, shape) = random_sparse(key, shape=(n, n), fill_ratio=0.01)
b = jax.random.uniform(key, shape=(n,))

while True:
    x = klujax.solve(indices[:, 0], indices[:, 1], data, b)
    jax.block_until_ready(x)

This script will make your machine go out of memory after just a few seconds.

To fix this, just add these lines to both solve functions in the ccp file:

delete[] Bk;
delete[] Bi;
delete[] Bp;
delete[] Bx;

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.