Mercurial > hg > octave-lyh
diff scripts/general/interp1.m @ 9754:4219e5cf773d
improve interp1 and pchip
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Thu, 22 Oct 2009 10:12:54 +0200 |
parents | e9dc2ed2ec0f |
children | 9a1c4fe44af8 |
line wrap: on
line diff
--- a/scripts/general/interp1.m +++ b/scripts/general/interp1.m @@ -126,15 +126,13 @@ ## reshape matrices for convenience x = x(:); - nx = size (x, 1); - if (isvector(y) && size (y, 1) == 1) + nx = rows (x); + if (isvector (y)) y = y(:); endif - ndy = ndims (y); szy = size (y); - ny = szy(1); - nc = prod (szy(2:end)); - y = reshape (y, ny, nc); + y = y(:,:); + [ny, nc] = size (y); szx = size (xi); xi = xi(:); @@ -143,61 +141,42 @@ error ("interp1: table too short"); endif - ## determine which values are out of range and set them to extrap, - ## unless extrap == "extrap" in which case, extrapolate them like we - ## should be doing in the first place. - minx = x(1); - maxx = x(nx); - if (minx > maxx) - tmp = minx; - minx = maxx; - maxx = tmp; - endif - if (method(1) == "*") - dx = x(2) - x(1); + ## check whether x is sorted; sort if not. + if (! issorted (x)) + [x, p] = sort (x); + y = y(p,:); endif - if (! pp) - if (ischar (extrap) && strcmp (extrap, "extrap")) - range = 1:size (xi, 1); - yi = zeros (size (xi, 1), size (y, 2)); - else - range = find (xi >= minx & xi <= maxx); - yi = extrap * ones (size (xi, 1), size (y, 2)); - if (isempty (range)) - if (! isvector (y) && length (szx) == 2 - && (szx(1) == 1 || szx(2) == 1)) - if (szx(1) == 1) - yi = reshape (yi, [szx(2), szy(2:end)]); - else - yi = reshape (yi, [szx(1), szy(2:end)]); - endif - else - yi = reshape (yi, [szx, szy(2:end)]); - endif - return; - endif - xi = xi(range); + starmethod = method(1) == "*"; + + if (starmethod) + dx = x(2) - x(1); + else + if (any (x(1:nx-1) == x(2:nx))) + error ("interp1: table must be strictly monotonic"); endif endif - if (strcmp (method, "nearest")) + ## Proceed with interpolating by all methods. + + switch (method) + case "nearest" if (pp) yi = mkpp ([x(1); (x(1:end-1)+x(2:end))/2; x(end)], y, szy(2:end)); else idx = lookup (0.5*(x(1:nx-1)+x(2:nx)), xi) + 1; - yi(range,:) = y(idx,:); + yi = y(idx,:); endif - elseif (strcmp (method, "*nearest")) + case "*nearest" if (pp) yi = mkpp ([x(1); x(1)+[0.5:(ny-1)]'*dx; x(nx)], y, szy(2:end)); else idx = max (1, min (ny, floor((xi-x(1))/dx+1.5))); - yi(range,:) = y(idx,:); + yi = y(idx,:); endif - elseif (strcmp (method, "linear")) - dy = y(2:ny,:) - y(1:ny-1,:); - dx = x(2:nx) - x(1:nx-1); + case "linear" + dy = diff (y); + dx = diff (x); if (pp) yi = mkpp (x, [dy./dx, y(1:end-1)], szy(2:end)); else @@ -205,24 +184,23 @@ idx = lookup (x, xi, "lr"); ## use the endpoints of the interval to define a line s = (xi - x(idx))./dx(idx); - yi(range,:) = s(:,ones(1,nc)).*dy(idx,:) + y(idx,:); + yi = bsxfun (@times, s, dy(idx,:)) + y(idx,:); endif - elseif (strcmp (method, "*linear")) + case "*linear" + dy = diff (y); if (pp) - dy = [y(2:ny,:) - y(1:ny-1,:)]; yi = mkpp (x(1) + [0:ny-1]*dx, [dy./dx, y(1:end-1)], szy(2:end)); else ## find the interval containing the test point t = (xi - x(1))/dx + 1; - idx = max(1,min(ny,floor(t))); + idx = max (1, min (ny - 1, floor (t))); ## use the endpoints of the interval to define a line - dy = [y(2:ny,:) - y(1:ny-1,:); y(ny,:) - y(ny-1,:)]; s = t - idx; - yi(range,:) = s(:,ones(1,nc)).*dy(idx,:) + y(idx,:); + yi = bsxfun (@times, s, dy(idx,:)) + y(idx,:); endif - elseif (strcmp (method, "pchip") || strcmp (method, "*pchip")) - if (nx == 2 || method(1) == "*") + case {"pchip", "*pchip"} + if (nx == 2 || starmethod) x = linspace (x(1), x(nx), ny); endif ## Note that pchip's arguments are transposed relative to interp1 @@ -230,69 +208,67 @@ yi = pchip (x.', y.'); yi.d = szy(2:end); else - yi(range,:) = pchip (x.', y.', xi.').'; - endif - - elseif (strcmp (method, "cubic") || (strcmp (method, "*cubic") && pp)) - ## FIXME Is there a better way to treat pp return return and *cubic - if (method(1) == "*") - x = linspace (x(1), x(nx), ny).'; - nx = ny; + yi = pchip (x.', y.', xi.').'; endif - if (nx < 4 || ny < 4) - error ("interp1: table too short"); - endif - idx = lookup (x(2:nx-1), xi, "lr"); - - ## Construct cubic equations for each interval using divided - ## differences (computation of c and d don't use divided differences - ## but instead solve 2 equations for 2 unknowns). Perhaps - ## reformulating this as a lagrange polynomial would be more efficient. - i = 1:nx-3; - J = ones (1, nc); - dx = diff (x); - dx2 = x(i+1).^2 - x(i).^2; - dx3 = x(i+1).^3 - x(i).^3; - a = diff (y, 3)./dx(i,J).^3/6; - b = (diff (y(1:nx-1,:), 2)./dx(i,J).^2 - 6*a.*x(i+1,J))/2; - c = (diff (y(1:nx-2,:), 1) - a.*dx3(:,J) - b.*dx2(:,J))./dx(i,J); - d = y(i,:) - ((a.*x(i,J) + b).*x(i,J) + c).*x(i,J); - - if (pp) - xs = [x(1);x(3:nx-2)]; - yi = mkpp ([x(1);x(3:nx-2);x(nx)], - [a(:), (b(:) + 3.*xs(:,J).*a(:)), ... - (c(:) + 2.*xs(:,J).*b(:) + 3.*xs(:,J)(:).^2.*a(:)), ... - (d(:) + xs(:,J).*c(:) + xs(:,J).^2.*b(:) + ... - xs(:,J).^3.*a(:))], szy(2:end)); - else - yi(range,:) = ((a(idx,:).*xi(:,J) + b(idx,:)).*xi(:,J) ... - + c(idx,:)).*xi(:,J) + d(idx,:); - endif - elseif (strcmp (method, "*cubic")) + case {"cubic", "*cubic"} if (nx < 4 || ny < 4) error ("interp1: table too short"); endif - ## From: Miloje Makivic - ## http://www.npac.syr.edu/projects/nasa/MILOJE/final/node36.html - t = (xi - x(1))/dx + 1; - idx = max (min (floor (t), ny-2), 2); - t = t - idx; - t2 = t.*t; - tp = 1 - 0.5*t; - a = (1 - t2).*tp; - b = (t2 + t).*tp; - c = (t2 - t).*tp/3; - d = (t2 - 1).*t/6; - J = ones (1, nc); + ## FIXME Is there a better way to treat pp return and *cubic + if (starmethod && ! pp) + ## From: Miloje Makivic + ## http://www.npac.syr.edu/projects/nasa/MILOJE/final/node36.html + t = (xi - x(1))/dx + 1; + idx = max (min (floor (t), ny-2), 2); + t = t - idx; + t2 = t.*t; + tp = 1 - 0.5*t; + a = (1 - t2).*tp; + b = (t2 + t).*tp; + c = (t2 - t).*tp/3; + d = (t2 - 1).*t/6; + J = ones (1, nc); + + yi = a(:,J) .* y(idx,:) + b(:,J) .* y(idx+1,:) ... + + c(:,J) .* y(idx-1,:) + d(:,J) .* y(idx+2,:); + else + if (starmethod) + x = linspace (x(1), x(nx), ny).'; + nx = ny; + endif + + idx = lookup (x(2:nx-1), xi, "lr"); - yi(range,:) = a(:,J) .* y(idx,:) + b(:,J) .* y(idx+1,:) ... - + c(:,J) .* y(idx-1,:) + d(:,J) .* y(idx+2,:); + ## Construct cubic equations for each interval using divided + ## differences (computation of c and d don't use divided differences + ## but instead solve 2 equations for 2 unknowns). Perhaps + ## reformulating this as a lagrange polynomial would be more efficient. + i = 1:nx-3; + J = ones (1, nc); + dx = diff (x); + dx2 = x(i+1).^2 - x(i).^2; + dx3 = x(i+1).^3 - x(i).^3; + a = diff (y, 3)./dx(i,J).^3/6; + b = (diff (y(1:nx-1,:), 2)./dx(i,J).^2 - 6*a.*x(i+1,J))/2; + c = (diff (y(1:nx-2,:), 1) - a.*dx3(:,J) - b.*dx2(:,J))./dx(i,J); + d = y(i,:) - ((a.*x(i,J) + b).*x(i,J) + c).*x(i,J); - elseif (strcmp (method, "spline") || strcmp (method, "*spline")) - if (nx == 2 || method(1) == "*") + if (pp) + xs = [x(1);x(3:nx-2)]; + yi = mkpp ([x(1);x(3:nx-2);x(nx)], + [a(:), (b(:) + 3.*xs(:,J).*a(:)), ... + (c(:) + 2.*xs(:,J).*b(:) + 3.*xs(:,J)(:).^2.*a(:)), ... + (d(:) + xs(:,J).*c(:) + xs(:,J).^2.*b(:) + ... + xs(:,J).^3.*a(:))], szy(2:end)); + else + yi = ((a(idx,:).*xi(:,J) + b(idx,:)).*xi(:,J) ... + + c(idx,:)).*xi(:,J) + d(idx,:); + endif + endif + case {"spline", "*spline"} + if (nx == 2 || starmethod) x = linspace(x(1), x(nx), ny); endif ## Note that spline's arguments are transposed relative to interp1 @@ -300,13 +276,23 @@ yi = spline (x.', y.'); yi.d = szy(2:end); else - yi(range,:) = spline (x.', y.', xi.').'; + yi = spline (x.', y.', xi.').'; endif - else + otherwise error ("interp1: invalid method '%s'", method); - endif + endswitch if (! pp) + if (! ischar (extrap)) + ## determine which values are out of range and set them to extrap, + ## unless extrap == "extrap". + minx = min (x(1), x(nx)); + maxx = max (x(1), x(nx)); + + outliers = xi < minx | ! (xi <= maxx); # this catches even NaNs + yi(outliers, :) = extrap; + endif + if (! isvector (y) && length (szx) == 2 && (szx(1) == 1 || szx(2) == 1)) if (szx(1) == 1) yi = reshape (yi, [szx(2), szy(2:end)]);