changeset 8398:d95282fa0579

allow element assignment to diagonal matrices
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 11 Dec 2008 11:04:00 +0100
parents 4780279e8094
children c1bada868690
files src/ChangeLog src/ov-base-diag.cc src/ov-base-diag.h src/ov-cx-diag.cc src/ov-cx-diag.h src/ov-flt-cx-diag.cc src/ov-flt-cx-diag.h src/ov-flt-re-diag.cc src/ov-flt-re-diag.h src/ov-re-diag.cc src/ov-re-diag.h
diffstat 11 files changed, 150 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,22 @@
+2008-12-11  Jaroslav Hajek  <highegg@gmail.com>
+
+	* ov-base-diag.cc (octave_base_diag<DMT,MT>::subsasgn): New method.
+	* ov-base-diag.h (octave_base_diag<DMT,MT>::subsasgn): Declare it.
+	(octave_base_diag<DMT,MT>::chk_valid_scalar): New method decl.
+
+	* ov-re-diag.cc (octave_diag_matrix::chk_valid_scalar): New method
+	override.
+	* ov-re-diag.h: Declare it.
+	* ov-flt-re-diag.cc (octave_float_diag_matrix::chk_valid_scalar): New
+	method override.
+	* ov-flt-re-diag.h: Declare it.
+	* ov-cx-diag.cc (octave_complex_diag_matrix::chk_valid_scalar): New 
+	method override.
+	* ov-cx-diag.h: Declare it.
+	* ov-flt-cx-diag.cc (octave_float_complex_diag_matrix::chk_valid_scalar): 
+	New method override.
+	* ov-flt-cx-diag.h: Declare it.
+
 2008-12-10  Jaroslav Hajek  <highegg@gmail.com>
 
 	* DLD-FUNCTIONS/expm.cc: Remove.
--- a/src/ov-base-diag.cc
+++ b/src/ov-base-diag.cc
@@ -128,6 +128,75 @@
 }
 
 template <class DMT, class MT>
+octave_value 
+octave_base_diag<DMT, MT>::subsasgn (const std::string& type,
+                                     const std::list<octave_value_list>& idx,
+                                     const octave_value& rhs)
+{
+  octave_value retval;
+
+  switch (type[0])
+    {
+    case '(':
+      {
+	if (type.length () == 1)
+          {
+            octave_value_list jdx = idx.front ();
+            // Check for a simple element assignment. That means, if D is a diagonal matrix,
+            // `D(i,i) = x' will not destroy its diagonality (provided i is a valid index).
+            if (jdx.length () == 2 && jdx(0).is_scalar_type () && jdx(1).is_scalar_type ())
+              {
+                typename DMT::element_type val;
+                idx_vector i0 = jdx(0).index_vector (), i1 = jdx(1).index_vector ();
+                if (! error_state  && i0(0) == i1(0) 
+                    && i0(0) < matrix.rows () && i1(0) < matrix.cols ()
+                    && chk_valid_scalar (rhs, val))
+                  {
+                    matrix (i0(0), i1(0)) = val;                    
+                    retval = this;
+                    this->count++;
+                    // invalidate cache
+                    dense_cache = octave_value ();
+                  }
+              }
+
+            if (! error_state && ! retval.is_defined ())
+              retval = numeric_assign (type, idx, rhs);
+          }
+	else
+	  {
+	    std::string nm = type_name ();
+	    error ("in indexed assignment of %s, last lhs index must be ()",
+		   nm.c_str ());
+	  }
+      }
+      break;
+
+    case '{':
+    case '.':
+      {
+	if (is_empty ())
+	  {
+	    octave_value tmp = octave_value::empty_conv (type, rhs);
+
+	    retval = tmp.subsasgn (type, idx, rhs);
+	  }
+	else
+	  {
+	    std::string nm = type_name ();
+	    error ("%s cannot be indexed with %c", nm.c_str (), type[0]);
+	  }
+      }
+      break;
+
+    default:
+      panic_impossible ();
+    }
+
+  return retval;
+}
+
+template <class DMT, class MT>
 octave_value
 octave_base_diag<DMT, MT>::resize (const dim_vector& dv, bool fill) const
 {
--- a/src/ov-base-diag.h
+++ b/src/ov-base-diag.h
@@ -72,6 +72,10 @@
   octave_value do_index_op (const octave_value_list& idx,
 			    bool resize_ok = false);
 
+  octave_value subsasgn (const std::string& type,
+			 const std::list<octave_value_list>& idx,
+			 const octave_value& rhs);
+
   dim_vector dims (void) const { return matrix.dims (); }
 
   octave_idx_type nnz (void) const { return to_dense ().nnz (); }
@@ -244,7 +248,10 @@
 
   DMT matrix;
 
-  octave_value to_dense () const;
+  octave_value to_dense (void) const;
+
+  virtual bool chk_valid_scalar (const octave_value&, 
+                                 typename DMT::element_type&) const = 0;
 
 private:
 
--- a/src/ov-cx-diag.cc
+++ b/src/ov-cx-diag.cc
@@ -220,3 +220,12 @@
   return true;
 }
 
+bool 
+octave_complex_diag_matrix::chk_valid_scalar (const octave_value& val, 
+                                              Complex& x) const
+{
+  bool retval = val.is_complex_scalar () || val.is_real_scalar ();
+  if (retval)
+    x = val.complex_value ();
+  return retval;
+}
--- a/src/ov-cx-diag.h
+++ b/src/ov-cx-diag.h
@@ -84,6 +84,10 @@
   octave_value real (void) const;
 
 private:
+
+  bool chk_valid_scalar (const octave_value&, 
+                         Complex&) const;
+
   DECLARE_OCTAVE_ALLOCATOR
 
   DECLARE_OV_TYPEID_FUNCTIONS_AND_DATA
--- a/src/ov-flt-cx-diag.cc
+++ b/src/ov-flt-cx-diag.cc
@@ -194,3 +194,13 @@
 
   return true;
 }
+
+bool 
+octave_float_complex_diag_matrix::chk_valid_scalar (const octave_value& val, 
+                                                    FloatComplex& x) const
+{
+  bool retval = val.is_complex_scalar () || val.is_real_scalar ();
+  if (retval)
+    x = val.float_complex_value ();
+  return retval;
+}
--- a/src/ov-flt-cx-diag.h
+++ b/src/ov-flt-cx-diag.h
@@ -82,6 +82,10 @@
   octave_value real (void) const;
 
 private:
+
+  bool chk_valid_scalar (const octave_value&, 
+                         FloatComplex&) const;
+
   DECLARE_OCTAVE_ALLOCATOR
 
   DECLARE_OV_TYPEID_FUNCTIONS_AND_DATA
--- a/src/ov-flt-re-diag.cc
+++ b/src/ov-flt-re-diag.cc
@@ -163,3 +163,13 @@
 
   return true;
 }
+
+bool 
+octave_float_diag_matrix::chk_valid_scalar (const octave_value& val, 
+                                            float& x) const
+{
+  bool retval = val.is_real_scalar ();
+  if (retval)
+    x = val.float_value ();
+  return retval;
+}
--- a/src/ov-flt-re-diag.h
+++ b/src/ov-flt-re-diag.h
@@ -82,6 +82,10 @@
   octave_value real (void) const;
 
 private:
+
+  bool chk_valid_scalar (const octave_value&, 
+                         float&) const;
+
   DECLARE_OCTAVE_ALLOCATOR
 
   DECLARE_OV_TYPEID_FUNCTIONS_AND_DATA
--- a/src/ov-re-diag.cc
+++ b/src/ov-re-diag.cc
@@ -189,3 +189,12 @@
   return true;
 }
 
+bool 
+octave_diag_matrix::chk_valid_scalar (const octave_value& val, 
+                                      double& x) const
+{
+  bool retval = val.is_real_scalar ();
+  if (retval)
+    x = val.double_value ();
+  return retval;
+}
--- a/src/ov-re-diag.h
+++ b/src/ov-re-diag.h
@@ -84,6 +84,10 @@
   octave_value real (void) const;
 
 private:
+
+  bool chk_valid_scalar (const octave_value&, 
+                         double&) const;
+
   DECLARE_OCTAVE_ALLOCATOR
 
   DECLARE_OV_TYPEID_FUNCTIONS_AND_DATA