changeset 3991:48d2bc4a3729

[project @ 2002-07-16 17:46:50 by jwe]
author jwe
date Tue, 16 Jul 2002 17:46:51 +0000
parents 46388d6a4e44
children 53b4eab68976
files liboctave/ChangeLog liboctave/DAEFunc.h liboctave/DASPK.cc liboctave/DASRT.cc liboctave/DASSL.cc liboctave/ODESSA.cc src/ChangeLog src/DLD-FUNCTIONS/dassl.cc src/DLD-FUNCTIONS/lsode.cc
diffstat 9 files changed, 202 insertions(+), 112 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog
+++ b/liboctave/ChangeLog
@@ -1,5 +1,9 @@
 2002-07-16  John W. Eaton  <jwe@bevo.che.wisc.edu>
 
+	* DAEFunc.h (DAEFunc): Jacobian function now follows format of DASSL.
+	* DASSL.cc (ddassl_j): Make it work.
+	* DASPK.cc (ddaspk_j): Likewise.
+
 	* DAE.cc: Delete.
 
 	* DAERT.h, DAERTFunc.h, DASRT.h, DASRT.cc: New files for DAE
--- a/liboctave/DAEFunc.h
+++ b/liboctave/DAEFunc.h
@@ -31,18 +31,17 @@
 {
 public:
 
-  struct DAEJac
-    {
-      Matrix *dfdxdot;
-      Matrix *dfdx;
-    };
-
   typedef ColumnVector (*DAERHSFunc) (const ColumnVector& x,
 				      const ColumnVector& xdot,
 				      double t, int& ires);
 
-  typedef DAEJac (*DAEJacFunc) (const ColumnVector& x,
-				const ColumnVector& xdot, double t);
+  // This is really the form used by DASSL:
+  //
+  //   PD = DG/DY + CJ * DG/DYPRIME
+
+  typedef Matrix (*DAEJacFunc) (const ColumnVector& x,
+				const ColumnVector& xdot,
+				double t, double cj);
 
   DAEFunc (void)
     : fun (0), jac (0) { }
--- a/liboctave/DASPK.cc
+++ b/liboctave/DASPK.cc
@@ -175,31 +175,27 @@
   return 0;
 }
 
+
 int
-ddaspk_j (const double& time, const double *, const double *, double *pd,
-	  const double& cj, double *, int *)
+ddaspk_j (const double& time, const double *state, const double *deriv,
+	  double *pd, const double& cj, double *, int *)
 {
+  // XXX FIXME XXX -- would be nice to avoid copying the data.
+
   ColumnVector tmp_state (nn);
   ColumnVector tmp_deriv (nn);
 
-  // XXX FIXME XXX
-
-  Matrix tmp_dfdxdot (nn, nn);
-  Matrix tmp_dfdx (nn, nn);
+  for (int i = 0; i < nn; i++)
+    {
+      tmp_deriv.elem (i) = deriv [i];
+      tmp_state.elem (i) = state [i];
+    }
 
-  DAEFunc::DAEJac tmp_jac;
-  tmp_jac.dfdxdot = &tmp_dfdxdot;
-  tmp_jac.dfdx    = &tmp_dfdx;
-
-  tmp_jac = user_jac (tmp_state, tmp_deriv, time);
-
-  // Fix up the matrix of partial derivatives for daspk.
-
-  tmp_dfdx = tmp_dfdx + cj * tmp_dfdxdot;
+  Matrix tmp_pd = user_jac (tmp_state, tmp_deriv, time, cj);
 
   for (int j = 0; j < nn; j++)
     for (int i = 0; i < nn; i++)
-      pd [nn * j + i] = tmp_dfdx.elem (i, j);
+      pd [nn * j + i] = tmp_pd.elem (i, j);
 
   return 0;
 }
--- a/liboctave/DASRT.cc
+++ b/liboctave/DASRT.cc
@@ -45,10 +45,6 @@
 #include "utils.h"
 #include "variables.h"
 
-// For instantiating the Array<Matrix> object.
-#include "Array.h"
-#include "Array.cc"
-
 #include "DASRT.h"
 #include "f77-fcn.h"
 #include "lo-error.h"
@@ -57,23 +53,22 @@
 #define F77_FUNC(x, X) F77_FCN (x, X)
 #endif
 
+typedef int (*dasrt_fcn_ptr) (const double&, const double*, const double*,
+			      double*, int&, double*, int*);
+
+typedef int (*dasrt_jac_ptr) (const double&, const double*, const double*,
+			      double*, const double&, double*, int*);
+
+typedef int (*dasrt_constr_ptr) (const int&, const double&, const double*,
+				 const int&, double*, double*, int*);
+
 extern "C"
-{
-  int F77_FUNC (ddasrt, DASRT) (int (*)(const double&, double*, double*,
-					double*, int&, double*, int*),
-				const int&, const double&, double*, double*,
-				const double&, int*, double*, double*,
-				int&, double*, const int&, int*, 
-				const int&, double*, int*,
-				int (*)(const double&, double*,
-					double*, double*,
-					const double&, double*, int*),
-				int (*)(const int&, const double&, double*,
-					const int&, double*, double*, int*),
-				const int&, int*);
-}
-
-template class Array<Matrix>;
+int F77_FUNC (ddasrt, DASRT) (dasrt_fcn_ptr, const int&, double&,
+			      double*, double*, const double&, int*,
+			      double*, double*, int&, double*,
+			      const int&, int*, const int&, double*,
+			      int*, dasrt_jac_ptr, dasrt_constr_ptr,
+			      const int&, int*);
 
 static DAEFunc::DAERHSFunc user_fsub;
 static DAEFunc::DAEJacFunc user_jsub;
@@ -81,8 +76,8 @@
 static int nn;
 
 static int
-ddasrt_f (const double& t, double *state, double *deriv, double *delta,
-          int& ires, double *rpar, int *ipar)
+ddasrt_f (const double& t, const double *state, const double *deriv,
+	  double *delta, int& ires, double *rpar, int *ipar)
 {
   ColumnVector tmp_state (nn);
   for (int i = 0; i < nn; i++)
@@ -111,43 +106,33 @@
 
 //static efptr e_fun;
 
-static int
-ddasrt_j (const double& t, double *state, double *deriv,
-	  double *pdwork, const double& cj, double *rpar, int *ipar) 
+int
+ddasrt_j (const double& time, const double *state, const double *deriv,
+	  double *pd, const double& cj, double *, int *)
 {
-  ColumnVector tmp_state (nn);
-  for (int i = 0; i < nn; i++)
-    tmp_state(i) = state[i];
-
-  ColumnVector tmp_deriv (nn);
-  for (int i = 0; i < nn; i++)
-    tmp_deriv(i) = deriv[i];
+  // XXX FIXME XXX -- would be nice to avoid copying the data.
 
-  // XXX FIXME XXX
-
-  Matrix tmp_dfdxdot (nn, nn);
-  Matrix tmp_dfdx (nn, nn);
+  ColumnVector tmp_state (nn);
+  ColumnVector tmp_deriv (nn);
 
-  DAEFunc::DAEJac tmp_jac;
-  tmp_jac.dfdxdot = &tmp_dfdxdot;
-  tmp_jac.dfdx    = &tmp_dfdx;
+  for (int i = 0; i < nn; i++)
+    {
+      tmp_deriv.elem (i) = deriv [i];
+      tmp_state.elem (i) = state [i];
+    }
 
-  tmp_jac = user_jsub (tmp_state, tmp_deriv, t);
-
-  // Fix up the matrix of partial derivatives for dasrt.
-
-  tmp_dfdx = tmp_dfdx + cj * tmp_dfdxdot;
+  Matrix tmp_pd = user_jsub (tmp_state, tmp_deriv, time, cj);
 
   for (int j = 0; j < nn; j++)
     for (int i = 0; i < nn; i++)
-      pdwork[j*nn+i] = tmp_dfdx.elem (i, j);
+      pd [nn * j + i] = tmp_pd.elem (i, j);
 
   return 0;
 }
 
 static int
-ddasrt_g (const int& neq, const double& t, double *state, const int& ng,
-	  double *gout, double *rpar, int *ipar) 
+ddasrt_g (const int& neq, const double& t, const double *state,
+	  const int& ng, double *gout, double *rpar, int *ipar) 
 {
   int n = neq;
 
--- a/liboctave/DASSL.cc
+++ b/liboctave/DASSL.cc
@@ -35,19 +35,19 @@
 #include "f77-fcn.h"
 #include "lo-error.h"
 
-typedef int (*dassl_fcn_ptr) (const double&, double*, double*,
+typedef int (*dassl_fcn_ptr) (const double&, const double*, const double*,
 			      double*, int&, double*, int*);
 
-typedef int (*dassl_jac_ptr) (const double&, double*, double*,
+typedef int (*dassl_jac_ptr) (const double&, const double*, const double*,
 			      double*, const double&, double*, int*);
 
 extern "C"
 int F77_FUNC (ddassl, DDASSL) (dassl_fcn_ptr, const int&, double&,
-			      double*, double*, double&, const int*,
-			      const double&, const double&, int&,
-			      double*, const int&, int*, const int&,
-			      const double*, const int*,
-			      dassl_jac_ptr);
+			       double*, double*, double&, const int*,
+			       const double&, const double&, int&,
+			       double*, const int&, int*, const int&,
+			       const double*, const int*,
+			       dassl_jac_ptr);
 
 static DAEFunc::DAERHSFunc user_fun;
 static DAEFunc::DAEJacFunc user_jac;
@@ -132,9 +132,11 @@
 }
 
 int
-ddassl_f (const double& time, double *state, double *deriv,
+ddassl_f (const double& time, const double *state, const double *deriv,
 	  double *delta, int& ires, double *, int *)
 {
+  // XXX FIXME XXX -- would be nice to avoid copying the data.
+
   ColumnVector tmp_deriv (nn);
   ColumnVector tmp_state (nn);
   ColumnVector tmp_delta (nn);
@@ -162,30 +164,25 @@
 }
 
 int
-ddassl_j (const double& time, double *, double *, double *pd, const
-	  double& cj, double *, int *)
+ddassl_j (const double& time, const double *state, const double *deriv,
+	  double *pd, const double& cj, double *, int *)
 {
+  // XXX FIXME XXX -- would be nice to avoid copying the data.
+
   ColumnVector tmp_state (nn);
   ColumnVector tmp_deriv (nn);
 
-  // XXX FIXME XXX
-
-  Matrix tmp_dfdxdot (nn, nn);
-  Matrix tmp_dfdx (nn, nn);
+  for (int i = 0; i < nn; i++)
+    {
+      tmp_deriv.elem (i) = deriv [i];
+      tmp_state.elem (i) = state [i];
+    }
 
-  DAEFunc::DAEJac tmp_jac;
-  tmp_jac.dfdxdot = &tmp_dfdxdot;
-  tmp_jac.dfdx    = &tmp_dfdx;
-
-  tmp_jac = user_jac (tmp_state, tmp_deriv, time);
-
-  // Fix up the matrix of partial derivatives for dassl.
-
-  tmp_dfdx = tmp_dfdx + cj * tmp_dfdxdot;
+  Matrix tmp_pd = user_jac (tmp_state, tmp_deriv, time, cj);
 
   for (int j = 0; j < nn; j++)
     for (int i = 0; i < nn; i++)
-      pd [nn * j + i] = tmp_dfdx.elem (i, j);
+      pd [nn * j + i] = tmp_pd.elem (i, j);
 
   return 0;
 }
--- a/liboctave/ODESSA.cc
+++ b/liboctave/ODESSA.cc
@@ -67,14 +67,14 @@
 }
 
 typedef int (*odessa_fcn_ptr) (int*, const double&, double*,
-                              double*, double*);
+			       double*, double*);
 
 typedef int (*odessa_jac_ptr) (int*, const double&, double*,
-                              double*, const int&, const int&,
-                              double*, const int&);
+			       double*, const int&, const int&,
+			       double*, const int&);
 
 typedef int (*odessa_dfdp_ptr) (int*, const double&, double*,
-                              double*, double*, const int&);
+				double*, double*, const int&);
 
 
 extern "C"
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,5 +1,8 @@
 2002-07-16  John W. Eaton  <jwe@bevo.che.wisc.edu>
 
+	* DLD-FUNCTIONS/dassl.cc (dassl_user_jacobian): New function.
+	(Fdassl): Handle Jacobian function.
+
 	* DLD-FUNCTIONS/dasrt.cc: New file.
 	* Makefile.in (DLD_XSRC): Add it to the list.
 
--- a/src/DLD-FUNCTIONS/dassl.cc
+++ b/src/DLD-FUNCTIONS/dassl.cc
@@ -44,6 +44,9 @@
 // Global pointer for user defined function required by dassl.
 static octave_function *dassl_fcn;
 
+// Global pointer for optional user defined jacobian function.
+static octave_function *dassl_jac;
+
 static DASSL_options dassl_opts;
 
 // Is this a recursive call?
@@ -114,6 +117,70 @@
   return retval;
 }
 
+Matrix
+dassl_user_jacobian (const ColumnVector& x, const ColumnVector& xdot,
+		     double t, double cj)
+{
+  Matrix retval;
+
+  int nstates = x.capacity ();
+
+  assert (nstates == xdot.capacity ());
+
+  octave_value_list args;
+
+  args(3) = cj;
+  args(2) = t;
+
+  if (nstates > 1)
+    {
+      Matrix m1 (nstates, 1);
+      Matrix m2 (nstates, 1);
+      for (int i = 0; i < nstates; i++)
+	{
+	  m1 (i, 0) = x (i);
+	  m2 (i, 0) = xdot (i);
+	}
+      octave_value state (m1);
+      octave_value deriv (m2);
+      args(1) = deriv;
+      args(0) = state;
+    }
+  else
+    {
+      double d1 = x (0);
+      double d2 = xdot (0);
+      octave_value state (d1);
+      octave_value deriv (d2);
+      args(1) = deriv;
+      args(0) = state;
+    }
+
+  if (dassl_jac)
+    {
+      octave_value_list tmp = dassl_jac->do_multi_index_op (1, args);
+
+      if (error_state)
+	{
+	  gripe_user_supplied_eval ("dassl");
+	  return retval;
+	}
+
+      int tlen = tmp.length ();
+      if (tlen > 0 && tmp(0).is_defined ())
+	{
+	  retval = tmp(0).matrix_value ();
+
+	  if (error_state || retval.length () == 0)
+	    gripe_user_supplied_eval ("dassl");
+	}
+      else
+	gripe_user_supplied_eval ("dassl");
+    }
+
+  return retval;
+}
+
 #define DASSL_ABORT() \
   do \
     { \
@@ -138,7 +205,7 @@
     } \
   while (0)
 
-DEFUN_DLD (dassl, args, ,
+DEFUN_DLD (dassl, args, nargout,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {[@var{x}, @var{xdot}] =} dassl (@var{fcn}, @var{x0}, @var{xdot0}, @var{t}, @var{t_crit})\n\
 Return a matrix of states and their first derivatives with respect to\n\
@@ -191,14 +258,48 @@
 
   int nargin = args.length ();
 
-  if (nargin > 3 && nargin < 6)
+  if (nargin > 3 && nargin < 6 && nargout < 5)
     {
-      dassl_fcn = extract_function
-	(args(0), "dassl", "__dassl_fcn__",
-	 "function res = __dassl_fcn__ (x, xdot, t) res = ",
-	 "; endfunction");
+      dassl_fcn = 0;
+      dassl_jac = 0;
+
+      octave_value f_arg = args(0);
+
+      switch (f_arg.rows ())
+	{
+	case 1:
+	  dassl_fcn = extract_function
+	    (f_arg, "dassl", "__dassl_fcn__",
+	     "function res = __dassl_fcn__ (x, xdot, t) res = ",
+	     "; endfunction");
+	  break;
+
+	case 2:
+	  {
+	    string_vector tmp = f_arg.all_strings ();
 
-      if (! dassl_fcn)
+	    if (! error_state)
+	      {
+		dassl_fcn = extract_function
+		  (tmp(0), "dassl", "__dassl_fcn__",
+		   "function res = __dassl_fcn__ (x, xdot, t) res = ",
+		   "; endfunction");
+
+		if (dassl_fcn)
+		  {
+		    dassl_jac = extract_function
+		      (tmp(1), "dassl", "__dassl_jac__",
+		       "function jac = __dassl_jac__ (x, xdot, t, cj) jac = ",
+		       "; endfunction");
+
+		    if (! dassl_jac)
+		      dassl_fcn = 0;
+		  }
+	      }
+	  }
+	}
+
+      if (error_state || ! dassl_fcn)
 	DASSL_ABORT ();
 
       ColumnVector state = ColumnVector (args(1).vector_value ());
@@ -234,7 +335,11 @@
       double tzero = out_times (0);
 
       DAEFunc func (dassl_user_function);
+      if (dassl_jac)
+	func.set_jacobian_function (dassl_user_jacobian);
+
       DASSL dae (state, deriv, tzero, func);
+
       dae.copy (dassl_opts);
 
       Matrix output;
@@ -247,10 +352,8 @@
 
       if (! error_state)
 	{
-	  retval.resize (2);
-
+	  retval(1) = deriv_output;
 	  retval(0) = output;
-	  retval(1) = deriv_output;
 	}
     }
   else
@@ -394,7 +497,7 @@
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {} dassl_options (@var{opt}, @var{val})\n\
 When called with two arguments, this function allows you set options\n\
-parameters for the function @code{lsode}.  Given one argument,\n\
+parameters for the function @code{dassl}.  Given one argument,\n\
 @code{dassl_options} returns the value of the corresponding option.  If\n\
 no arguments are supplied, the names of all the available options and\n\
 their current values are displayed.\n\
--- a/src/DLD-FUNCTIONS/lsode.cc
+++ b/src/DLD-FUNCTIONS/lsode.cc
@@ -201,20 +201,23 @@
 
   if (nargin > 2 && nargin < 5 && nargout < 4)
     {
+      lsode_fcn = 0;
+      lsode_jac = 0;
+
       octave_value f_arg = args(0);
 
       switch (f_arg.rows ())
 	{
 	case 1:
 	  lsode_fcn = extract_function
-	    (args(0), "lsode", "__lsode_fcn__",
+	    (f_arg, "lsode", "__lsode_fcn__",
 	     "function xdot = __lsode_fcn__ (x, t) xdot = ",
 	     "; endfunction");
 	  break;
 
 	case 2:
 	  {
-	    string_vector tmp = args(0).all_strings ();
+	    string_vector tmp = f_arg.all_strings ();
 
 	    if (! error_state)
 	      {