comparison src/data.cc @ 10716:f7f26094021b

improve cat code design in data.cc, make horzcat/vertcat more Matlab compatible
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 21 Jun 2010 15:48:56 +0200
parents fbd7843974fa
children f3892d8eea9f
comparison
equal deleted inserted replaced
10715:53253f796351 10716:f7f26094021b
1374 single_type_concat (Array<T>& result, 1374 single_type_concat (Array<T>& result,
1375 const octave_value_list& args, 1375 const octave_value_list& args,
1376 int dim) 1376 int dim)
1377 { 1377 {
1378 int n_args = args.length (); 1378 int n_args = args.length ();
1379 OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args - 1); 1379 OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args);
1380 1380
1381 for (int j = 1; j < n_args && ! error_state; j++) 1381 for (int j = 0; j < n_args && ! error_state; j++)
1382 { 1382 {
1383 octave_quit (); 1383 octave_quit ();
1384 1384
1385 array_list[j-1] = octave_value_extract<TYPE> (args(j)); 1385 array_list[j] = octave_value_extract<TYPE> (args(j));
1386 } 1386 }
1387 1387
1388 if (! error_state) 1388 if (! error_state)
1389 result = Array<T>::cat (dim, n_args-1, array_list); 1389 result = Array<T>::cat (dim, n_args, array_list);
1390 } 1390 }
1391 1391
1392 template <class TYPE, class T> 1392 template <class TYPE, class T>
1393 static void 1393 static void
1394 single_type_concat (Sparse<T>& result, 1394 single_type_concat (Sparse<T>& result,
1395 const octave_value_list& args, 1395 const octave_value_list& args,
1396 int dim) 1396 int dim)
1397 { 1397 {
1398 int n_args = args.length (); 1398 int n_args = args.length ();
1399 OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args-1); 1399 OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args);
1400 1400
1401 for (int j = 1; j < n_args && ! error_state; j++) 1401 for (int j = 0; j < n_args && ! error_state; j++)
1402 { 1402 {
1403 octave_quit (); 1403 octave_quit ();
1404 1404
1405 sparse_list[j-1] = octave_value_extract<TYPE> (args(j)); 1405 sparse_list[j] = octave_value_extract<TYPE> (args(j));
1406 } 1406 }
1407 1407
1408 if (! error_state) 1408 if (! error_state)
1409 result = Sparse<T>::cat (dim, n_args-1, sparse_list); 1409 result = Sparse<T>::cat (dim, n_args, sparse_list);
1410 } 1410 }
1411 1411
1412 // Dispatcher. 1412 // Dispatcher.
1413 template<class TYPE> 1413 template<class TYPE>
1414 static TYPE 1414 static TYPE
1420 1420
1421 return result; 1421 return result;
1422 } 1422 }
1423 1423
1424 static octave_value 1424 static octave_value
1425 do_cat (const octave_value_list& args, std::string fname) 1425 do_cat (const octave_value_list& args, int dim, std::string fname)
1426 { 1426 {
1427 octave_value retval; 1427 octave_value retval;
1428 1428
1429 int n_args = args.length (); 1429 int n_args = args.length ();
1430 1430
1431 if (n_args == 1) 1431 if (n_args == 0)
1432 retval = Matrix (); 1432 retval = Matrix ();
1433 else if (n_args == 2) 1433 else if (n_args == 1)
1434 retval = args(1); 1434 retval = args(0);
1435 else if (n_args > 2) 1435 else if (n_args > 1)
1436 { 1436 {
1437 octave_idx_type dim = args(0).int_value () - 1; 1437
1438 1438 std::string result_type = args(0).class_name ();
1439 if (error_state) 1439
1440 { 1440 bool all_sq_strings_p = args(0).is_sq_string ();
1441 error ("cat: expecting first argument to be a integer"); 1441 bool all_dq_strings_p = args(0).is_dq_string ();
1442 return retval; 1442 bool all_real_p = args(0).is_real_type ();
1443 } 1443 bool any_sparse_p = args(0).is_sparse_type();
1444 1444
1445 if (dim >= 0) 1445 for (int i = 1; i < args.length (); i++)
1446 { 1446 {
1447 1447 result_type =
1448 std::string result_type = args(1).class_name (); 1448 get_concat_class (result_type, args(i).class_name ());
1449 1449
1450 bool all_sq_strings_p = args(1).is_sq_string (); 1450 if (all_sq_strings_p && ! args(i).is_sq_string ())
1451 bool all_dq_strings_p = args(1).is_dq_string (); 1451 all_sq_strings_p = false;
1452 bool all_real_p = args(1).is_real_type (); 1452 if (all_dq_strings_p && ! args(i).is_dq_string ())
1453 bool any_sparse_p = args(1).is_sparse_type(); 1453 all_dq_strings_p = false;
1454 1454 if (all_real_p && ! args(i).is_real_type ())
1455 for (int i = 2; i < args.length (); i++) 1455 all_real_p = false;
1456 if (!any_sparse_p && args(i).is_sparse_type ())
1457 any_sparse_p = true;
1458 }
1459
1460 if (result_type == "double")
1461 {
1462 if (any_sparse_p)
1463 {
1464 if (all_real_p)
1465 retval = do_single_type_concat<SparseMatrix> (args, dim);
1466 else
1467 retval = do_single_type_concat<SparseComplexMatrix> (args, dim);
1468 }
1469 else
1456 { 1470 {
1457 result_type = 1471 if (all_real_p)
1458 get_concat_class (result_type, args(i).class_name ()); 1472 retval = do_single_type_concat<NDArray> (args, dim);
1459 1473 else
1460 if (all_sq_strings_p && ! args(i).is_sq_string ()) 1474 retval = do_single_type_concat<ComplexNDArray> (args, dim);
1461 all_sq_strings_p = false;
1462 if (all_dq_strings_p && ! args(i).is_dq_string ())
1463 all_dq_strings_p = false;
1464 if (all_real_p && ! args(i).is_real_type ())
1465 all_real_p = false;
1466 if (!any_sparse_p && args(i).is_sparse_type ())
1467 any_sparse_p = true;
1468 } 1475 }
1469 1476 }
1470 if (result_type == "double") 1477 else if (result_type == "single")
1478 {
1479 if (all_real_p)
1480 retval = do_single_type_concat<FloatNDArray> (args, dim);
1481 else
1482 retval = do_single_type_concat<FloatComplexNDArray> (args, dim);
1483 }
1484 else if (result_type == "char")
1485 {
1486 char type = all_dq_strings_p ? '"' : '\'';
1487
1488 maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p);
1489
1490 charNDArray result = do_single_type_concat<charNDArray> (args, dim);
1491
1492 retval = octave_value (result, type);
1493 }
1494 else if (result_type == "logical")
1495 {
1496 if (any_sparse_p)
1497 retval = do_single_type_concat<SparseBoolMatrix> (args, dim);
1498 else
1499 retval = do_single_type_concat<boolNDArray> (args, dim);
1500 }
1501 else if (result_type == "int8")
1502 retval = do_single_type_concat<int8NDArray> (args, dim);
1503 else if (result_type == "int16")
1504 retval = do_single_type_concat<int16NDArray> (args, dim);
1505 else if (result_type == "int32")
1506 retval = do_single_type_concat<int32NDArray> (args, dim);
1507 else if (result_type == "int64")
1508 retval = do_single_type_concat<int64NDArray> (args, dim);
1509 else if (result_type == "uint8")
1510 retval = do_single_type_concat<uint8NDArray> (args, dim);
1511 else if (result_type == "uint16")
1512 retval = do_single_type_concat<uint16NDArray> (args, dim);
1513 else if (result_type == "uint32")
1514 retval = do_single_type_concat<uint32NDArray> (args, dim);
1515 else if (result_type == "uint64")
1516 retval = do_single_type_concat<uint64NDArray> (args, dim);
1517 else
1518 {
1519 dim_vector dv = args(0).dims ();
1520
1521 // Default concatenation.
1522 bool (dim_vector::*concat_rule) (const dim_vector&, int) = &dim_vector::concat;
1523
1524 if (dim == -1 || dim == -2)
1471 { 1525 {
1472 if (any_sparse_p) 1526 concat_rule = &dim_vector::hvcat;
1473 { 1527 dim = -dim - 1;
1474 if (all_real_p) 1528 }
1475 retval = do_single_type_concat<SparseMatrix> (args, dim); 1529
1476 else 1530 for (int i = 1; i < args.length (); i++)
1477 retval = do_single_type_concat<SparseComplexMatrix> (args, dim); 1531 {
1532 if (! (dv.*concat_rule) (args(i).dims (), dim))
1533 {
1534 // Dimensions do not match.
1535 error ("cat: dimension mismatch");
1536 return retval;
1537 }
1538 }
1539
1540 // The lines below might seem crazy, since we take a copy
1541 // of the first argument, resize it to be empty and then resize
1542 // it to be full. This is done since it means that there is no
1543 // recopying of data, as would happen if we used a single resize.
1544 // It should be noted that resize operation is also significantly
1545 // slower than the do_cat_op function, so it makes sense to have
1546 // an empty matrix and copy all data.
1547 //
1548 // We might also start with an empty octave_value using
1549 // tmp = octave_value_typeinfo::lookup_type
1550 // (args(0).type_name());
1551 // and then directly resize. However, for some types there might
1552 // be some additional setup needed, and so this should be avoided.
1553
1554 octave_value tmp = args (0);
1555 tmp = tmp.resize (dim_vector (0,0)).resize (dv);
1556
1557 if (error_state)
1558 return retval;
1559
1560 int dv_len = dv.length ();
1561 Array<octave_idx_type> ra_idx (dv_len, 1, 0);
1562
1563 for (int j = 0; j < n_args; j++)
1564 {
1565 // Can't fast return here to skip empty matrices as something
1566 // like cat(1,[],single([])) must return an empty matrix of
1567 // the right type.
1568 tmp = do_cat_op (tmp, args (j), ra_idx);
1569
1570 if (error_state)
1571 return retval;
1572
1573 dim_vector dv_tmp = args (j).dims ();
1574
1575 if (dim >= dv_len)
1576 {
1577 if (j > 1)
1578 error ("%s: indexing error", fname.c_str ());
1579 break;
1478 } 1580 }
1479 else 1581 else
1480 { 1582 ra_idx (dim) += (dim < dv_tmp.length () ?
1481 if (all_real_p) 1583 dv_tmp (dim) : 1);
1482 retval = do_single_type_concat<NDArray> (args, dim);
1483 else
1484 retval = do_single_type_concat<ComplexNDArray> (args, dim);
1485 }
1486 } 1584 }
1487 else if (result_type == "single") 1585 retval = tmp;
1488 { 1586 }
1489 if (all_real_p)
1490 retval = do_single_type_concat<FloatNDArray> (args, dim);
1491 else
1492 retval = do_single_type_concat<FloatComplexNDArray> (args, dim);
1493 }
1494 else if (result_type == "char")
1495 {
1496 char type = all_dq_strings_p ? '"' : '\'';
1497
1498 maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p);
1499
1500 charNDArray result = do_single_type_concat<charNDArray> (args, dim);
1501
1502 retval = octave_value (result, type);
1503 }
1504 else if (result_type == "logical")
1505 {
1506 if (any_sparse_p)
1507 retval = do_single_type_concat<SparseBoolMatrix> (args, dim);
1508 else
1509 retval = do_single_type_concat<boolNDArray> (args, dim);
1510 }
1511 else if (result_type == "int8")
1512 retval = do_single_type_concat<int8NDArray> (args, dim);
1513 else if (result_type == "int16")
1514 retval = do_single_type_concat<int16NDArray> (args, dim);
1515 else if (result_type == "int32")
1516 retval = do_single_type_concat<int32NDArray> (args, dim);
1517 else if (result_type == "int64")
1518 retval = do_single_type_concat<int64NDArray> (args, dim);
1519 else if (result_type == "uint8")
1520 retval = do_single_type_concat<uint8NDArray> (args, dim);
1521 else if (result_type == "uint16")
1522 retval = do_single_type_concat<uint16NDArray> (args, dim);
1523 else if (result_type == "uint32")
1524 retval = do_single_type_concat<uint32NDArray> (args, dim);
1525 else if (result_type == "uint64")
1526 retval = do_single_type_concat<uint64NDArray> (args, dim);
1527 else
1528 {
1529 dim_vector dv = args(1).dims ();
1530
1531 for (int i = 2; i < args.length (); i++)
1532 {
1533 if (! dv.concat (args(i).dims (), dim))
1534 {
1535 // Dimensions do not match.
1536 error ("cat: dimension mismatch");
1537 return retval;
1538 }
1539 }
1540
1541 // The lines below might seem crazy, since we take a copy
1542 // of the first argument, resize it to be empty and then resize
1543 // it to be full. This is done since it means that there is no
1544 // recopying of data, as would happen if we used a single resize.
1545 // It should be noted that resize operation is also significantly
1546 // slower than the do_cat_op function, so it makes sense to have
1547 // an empty matrix and copy all data.
1548 //
1549 // We might also start with a empty octave_value using
1550 // tmp = octave_value_typeinfo::lookup_type
1551 // (args(1).type_name());
1552 // and then directly resize. However, for some types there might
1553 // be some additional setup needed, and so this should be avoided.
1554
1555 octave_value tmp = args (1);
1556 tmp = tmp.resize (dim_vector (0,0)).resize (dv);
1557
1558 if (error_state)
1559 return retval;
1560
1561 int dv_len = dv.length ();
1562 Array<octave_idx_type> ra_idx (dv_len, 1, 0);
1563
1564 for (int j = 1; j < n_args; j++)
1565 {
1566 // Can't fast return here to skip empty matrices as something
1567 // like cat(1,[],single([])) must return an empty matrix of
1568 // the right type.
1569 tmp = do_cat_op (tmp, args (j), ra_idx);
1570
1571 if (error_state)
1572 return retval;
1573
1574 dim_vector dv_tmp = args (j).dims ();
1575
1576 if (dim >= dv_len)
1577 {
1578 if (j > 1)
1579 error ("%s: indexing error", fname.c_str ());
1580 break;
1581 }
1582 else
1583 ra_idx (dim) += (dim < dv_tmp.length () ?
1584 dv_tmp (dim) : 1);
1585 }
1586 retval = tmp;
1587 }
1588 }
1589 else
1590 error ("%s: invalid dimension argument", fname.c_str ());
1591 } 1587 }
1592 else 1588 else
1593 print_usage (); 1589 print_usage ();
1594 1590
1595 return retval; 1591 return retval;
1601 Return the horizontal concatenation of N-d array objects, @var{array1},\n\ 1597 Return the horizontal concatenation of N-d array objects, @var{array1},\n\
1602 @var{array2}, @dots{}, @var{arrayN} along dimension 2.\n\ 1598 @var{array2}, @dots{}, @var{arrayN} along dimension 2.\n\
1603 @seealso{cat, vertcat}\n\ 1599 @seealso{cat, vertcat}\n\
1604 @end deftypefn") 1600 @end deftypefn")
1605 { 1601 {
1606 octave_value_list args_tmp = args; 1602 return do_cat (args, -2, "horzcat");
1607
1608 int dim = 2;
1609
1610 octave_value d (dim);
1611
1612 args_tmp.prepend (d);
1613
1614 return do_cat (args_tmp, "horzcat");
1615 } 1603 }
1616 1604
1617 DEFUN (vertcat, args, , 1605 DEFUN (vertcat, args, ,
1618 "-*- texinfo -*-\n\ 1606 "-*- texinfo -*-\n\
1619 @deftypefn {Built-in Function} {} vertcat (@var{array1}, @var{array2}, @dots{}, @var{arrayN})\n\ 1607 @deftypefn {Built-in Function} {} vertcat (@var{array1}, @var{array2}, @dots{}, @var{arrayN})\n\
1620 Return the vertical concatenation of N-d array objects, @var{array1},\n\ 1608 Return the vertical concatenation of N-d array objects, @var{array1},\n\
1621 @var{array2}, @dots{}, @var{arrayN} along dimension 1.\n\ 1609 @var{array2}, @dots{}, @var{arrayN} along dimension 1.\n\
1622 @seealso{cat, horzcat}\n\ 1610 @seealso{cat, horzcat}\n\
1623 @end deftypefn") 1611 @end deftypefn")
1624 { 1612 {
1625 octave_value_list args_tmp = args; 1613 return do_cat (args, -1, "vertcat");
1626
1627 int dim = 1;
1628
1629 octave_value d (dim);
1630
1631 args_tmp.prepend (d);
1632
1633 return do_cat (args_tmp, "vertcat");
1634 } 1614 }
1635 1615
1636 DEFUN (cat, args, , 1616 DEFUN (cat, args, ,
1637 "-*- texinfo -*-\n\ 1617 "-*- texinfo -*-\n\
1638 @deftypefn {Built-in Function} {} cat (@var{dim}, @var{array1}, @var{array2}, @dots{}, @var{arrayN})\n\ 1618 @deftypefn {Built-in Function} {} cat (@var{dim}, @var{array1}, @var{array2}, @dots{}, @var{arrayN})\n\
1679 @end group\n\ 1659 @end group\n\
1680 @end example\n\ 1660 @end example\n\
1681 @seealso{horzcat, vertcat}\n\ 1661 @seealso{horzcat, vertcat}\n\
1682 @end deftypefn") 1662 @end deftypefn")
1683 { 1663 {
1684 return do_cat (args, "cat"); 1664 octave_value retval;
1665
1666 if (args.length () > 0)
1667 {
1668 int dim = args(0).int_value () - 1;
1669
1670 if (! error_state)
1671 {
1672 if (dim >= 0)
1673 retval = do_cat (args.slice (1, args.length () - 1), dim, "cat");
1674 else
1675 error ("cat: invalid dimension specified");
1676 }
1677 else
1678 error ("cat: expecting first argument to be a integer");
1679 }
1680 else
1681 print_usage ();
1682
1683 return retval;
1685 } 1684 }
1686 1685
1687 /* 1686 /*
1688 1687
1689 %!function ret = testcat (t1, t2, tr, cmplx) 1688 %!function ret = testcat (t1, t2, tr, cmplx)