

relf commented on September 1, 2024

@kanekosh I used your script to create a new test, and thanks to @sdubreui and @NatOnera, PR #188 fixes the bug.

from smt.

relf commented on September 1, 2024

Thanks for this thorough report; there is definitely a problem with the derivatives when the linear option is used. I'll try to take a look...
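A quick way to confirm a derivative bug like this is to compare the analytic derivatives against central finite differences of the predicted values. The helper below is a hypothetical sketch, not part of SMT; it only assumes that `predict_values(x)` and `predict_derivatives(x, kx)` both return arrays of shape `(n, 1)`, so with an SMT model you would pass `t.predict_values` and `t.predict_derivatives`.

```python
import numpy as np


def finite_difference_check(predict_values, predict_derivatives, x, kx, h=1e-6):
    """Return the largest absolute gap between an analytic derivative and a
    central finite difference along input direction kx.

    Hypothetical helper: both callables are assumed to return shape (n, 1).
    """
    x = np.asarray(x, dtype=float)
    step = np.zeros(x.shape[1])
    step[kx] = h  # perturb only the kx-th input
    fd = (predict_values(x + step) - predict_values(x - step)) / (2.0 * h)
    return np.max(np.abs(predict_derivatives(x, kx) - fd))


if __name__ == "__main__":
    # Toy nonlinear model y = sum(x**2), whose exact derivative is 2 * x_kx.
    f = lambda x: np.sum(x ** 2, axis=1, keepdims=True)
    df = lambda x, kx: 2.0 * x[:, [kx]]
    x = np.random.default_rng(0).uniform(size=(5, 3))
    for kx in range(3):
        print(kx, finite_difference_check(f, df, x, kx))
```

A gap much larger than roughly `1e-6` for a smooth model points at the derivative code rather than at the model itself.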


dmichalis commented on September 1, 2024

Not sure if it is my fault, but I get the following error whenever I try to call `predict_derivatives` in the wing weight example:
```
File "/home/dimitris/SMT/smt-master/smt/surrogate_models/krg_based.py", line 348, in _predict_derivatives
    / self.X_std[kx]
IndexError: index 2 is out of bounds for axis 0 with size 2
```
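The failing line divides by `self.X_std[kx]`, which suggests the derivative is computed on standardized data and then mapped back to physical units. As a generic illustration of that chain rule (an assumption about the scaling, not SMT's actual code; the function name is hypothetical): with $x_s = (x - \bar{x}) / \sigma_x$ and $y = \bar{y} + \sigma_y y_s$, one gets $\partial y / \partial x_k = \sigma_y \, (\partial y_s / \partial x_{s,k}) / \sigma_{x,k}$, so the scale array needs one entry per input dimension. The sketch also shows that NumPy raises exactly this `IndexError` when that array is shorter than `kx + 1`.

```python
import numpy as np


def scaled_to_physical_derivative(dys_dxs, kx, X_std, y_std):
    """Map a derivative computed on standardized data back to physical units.

    Hypothetical helper: assumes x_s = (x - X_mean) / X_std and
    y = y_mean + y_std * y_s, so dy/dx_k = y_std * dy_s/dx_s_k / X_std[k].
    X_std must contain one standard deviation per input dimension.
    """
    return dys_dxs * y_std / X_std[kx]


if __name__ == "__main__":
    # Exact linear data: fitting in scaled space and unscaling recovers w.
    rng = np.random.default_rng(1)
    X = rng.uniform(0.0, 10.0, size=(50, 3))
    w = np.array([3.0, -2.0, 0.25])
    y = X @ w + 7.0

    X_mean, X_std = X.mean(axis=0), X.std(axis=0)
    y_mean, y_std = y.mean(), y.std()
    Xs = (X - X_mean) / X_std
    ys = (y - y_mean) / y_std

    # Both sides are centered, so a no-intercept least-squares fit is exact.
    coef_s = np.linalg.lstsq(Xs, ys, rcond=None)[0]
    print([scaled_to_physical_derivative(coef_s[k], k, X_std, y_std) for k in range(3)])

    # A scale array with only 2 entries cannot serve kx = 2.
    try:
        scaled_to_physical_derivative(1.0, 2, np.ones(2), 1.0)
    except IndexError as err:
        print(err)  # prints: index 2 is out of bounds for axis 0 with size 2
```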

I get the same error with every surrogate model that uses `krg_based.py`. The script I use is the following:

```python
from __future__ import print_function, division

import numpy as np

from smt.utils import compute_rms_error
from smt.problems import WingWeight
from smt.sampling_methods import LHS, Random
from smt.surrogate_models import LS, QP, KPLS, KRG, KPLSK, GEKPLS

try:
    from smt.surrogate_models import IDW, RBF, RMTC, RMTB
    compiled_available = True
except ImportError:
    compiled_available = False

try:
    import matplotlib.pyplot as plt
    plot_status = True
except ImportError:
    plot_status = False

########### Initialization of the problem, construction of the training and validation points

ndim = 10  # number of variables
ndoe = 50  # number of training points

# Construction of the DOE for the function
fun = WingWeight(ndim=ndim)
sampling = LHS(xlimits=fun.xlimits, criterion='cm')
xt = sampling(ndoe)

# Compute the outputs
yt = fun(xt)

# Construction of the validation points
ntest = 100
sampling = Random(xlimits=fun.xlimits)
xtest = sampling(ntest)
ytest = fun(xtest)

# Plot of validation vs. sampling points
if plot_status:
    fig = plt.figure(1)
    ax = fig.add_subplot(2, 1, 1)
    ax.plot(yt, yt, 'o', label='sampling points')
    ax.plot(ytest, ytest, 'r.', label='validation points')
    ax.set_xlabel('y')
    ax.set_ylabel('y')
    ax.legend(loc='upper left')
    ax.set_title('Plot of validation vs sampling points')
    ax.grid()
    plt.show()

#-------------------------- The KPLSK model--------------------------------#
t = KPLSK(n_comp=10, theta0=[1e-2] * ndim, print_prediction=False)
t.set_training_values(xt, yt)
t.train()

# Prediction of the validation points
y = t.predict_values(xtest)
print('KPLSK model, err: ' + str(compute_rms_error(t, xtest, ytest)))

if plot_status:
    # Plot the predicted values against the true values at the validation points
    fig = plt.figure()
    plt.plot(ytest, ytest, '-', label=r'$y_{true}$')
    plt.plot(ytest, y, 'r.', label=r'$\hat{y}$')
    plt.xlabel(r'$y_{true}$')
    plt.ylabel(r'$\hat{y}$')
    plt.legend(loc='upper left')
    plt.title('KPLSK model: validation of the prediction model')
    plt.show()

# Value of theta
print("theta values", t.optimal_theta)

# Exact gradient at the DOE points (one column per input direction)
ydt = np.zeros((ndoe, ndim))
for i in range(ndim):
    ydt[:, i] = fun(xt, kx=i)[:, 0]

# Exact gradient at the validation points
ydtest = np.zeros((ntest, ndim))
for i in range(ndim):
    ydtest[:, i] = fun(xtest, kx=i)[:, 0]

# Prediction of the derivatives with respect to each input direction
yd_prediction = np.zeros((ntest, ndim))
for i in range(ndim):
    yd_prediction[:, i] = t.predict_derivatives(xtest, kx=i)[:, 0]
    print('KPLSK model, err of the ' + str(i + 1) + '-th derivative: '
          + str(compute_rms_error(t, xtest, ydtest[:, i], kx=i)))

    if plot_status:
        plt.figure()
        plt.plot(ydtest[:, i], ydtest[:, i], '-.')
        plt.plot(ydtest[:, i], yd_prediction[:, i], '.')
        plt.title('KPLSK model, prediction of the ' + str(i + 1) + '-th derivative')
        plt.show()
```
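Once `ydtest` and `yd_prediction` are filled, a per-direction error can also be computed for all directions at once with plain NumPy, without calling `predict_derivatives` again. This sketch assumes the metric is the relative 2-norm $\lVert \hat{y}'_k - y'_k \rVert / \lVert y'_k \rVert$ per column; that is our reading of `compute_rms_error`, not something confirmed here, and `relative_rms_errors` is a hypothetical name.

```python
import numpy as np


def relative_rms_errors(jac_true, jac_pred):
    """Per-direction relative error ||pred_k - true_k|| / ||true_k||.

    jac_true, jac_pred: arrays of shape (n_points, ndim), one column per kx.
    Returns an array of shape (ndim,).
    """
    jac_true = np.asarray(jac_true, dtype=float)
    jac_pred = np.asarray(jac_pred, dtype=float)
    num = np.linalg.norm(jac_pred - jac_true, axis=0)  # column-wise 2-norms
    den = np.linalg.norm(jac_true, axis=0)
    return num / den


if __name__ == "__main__":
    # Stand-ins for ydtest and yd_prediction from the script above.
    true = np.array([[1.0, 2.0], [1.0, 2.0]])
    pred = np.array([[1.0, 2.0], [1.0, 2.2]])
    print(relative_rms_errors(true, pred))
```

In the script, the call would be `relative_rms_errors(ydtest, yd_prediction)`.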


relf commented on September 1, 2024

You're right, there is still a bug here when `poly` is `'constant'` and `ndim > 2`. I will push a fix.
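For reference when testing such a fix: the derivative of the regression trend term has a shape set by `poly`, not by `ndim`. For `'constant'` the basis is $[1]$, so its derivative is a single zero column; for `'linear'` the basis is $[1, x_1, \dots, x_d]$, so $\partial / \partial x_k$ is a one-hot row at column $k+1$. The helpers below are hypothetical illustrations of these shapes, not SMT's implementation, and make no claim about where the actual bug lies.

```python
import numpy as np


def trend_basis(x, poly):
    """Hypothetical regression basis: 'constant' -> [1], 'linear' -> [1, x_1..x_d]."""
    n = x.shape[0]
    if poly == "constant":
        return np.ones((n, 1))
    if poly == "linear":
        return np.hstack([np.ones((n, 1)), x])
    raise ValueError("unsupported poly: %r" % poly)


def trend_basis_derivative(x, kx, poly):
    """d(basis)/dx_kx, with one column per basis function (not per input)."""
    n, ndim = x.shape
    if poly == "constant":
        return np.zeros((n, 1))  # a constant does not depend on any x_k
    if poly == "linear":
        d = np.zeros((n, ndim + 1))
        d[:, kx + 1] = 1.0  # column 0 is the constant term
        return d
    raise ValueError("unsupported poly: %r" % poly)


if __name__ == "__main__":
    x = np.random.default_rng(2).uniform(size=(4, 10))
    for poly in ("constant", "linear"):
        print(poly, trend_basis(x, poly).shape, trend_basis_derivative(x, 2, poly).shape)
```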
