2024-08-17 18:19:10 -04:00

136 lines
4.0 KiB
Python

import torch
def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    Evaluate x on B-spline bases.

    The bases are built with the Cox-de Boor recursion: the order-0 bases are
    indicator functions of the half-open grid intervals [g_i, g_{i+1}), and
    each higher order blends two neighbouring bases of the previous order.

    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (batch, in_dim)
        grid : 2D torch.tensor
            grids, shape (in_dim, number of grid points)
        k : int
            the piecewise polynomial order of splines.
        extend : bool
            Not read by this function; kept so existing callers keep working.
        device : str
            Not read by this function; the result is computed on whatever
            device x and grid already live on.

    Returns:
    --------
        spline values : 3D torch.tensor
            shape (batch, in_dim, number of grid points - k - 1).
            Always a numeric tensor (also for k=0), in the promoted dtype of
            x and grid (the default float dtype for integer inputs).

    Example
    -------
    >>> from kan.spline import B_batch
    >>> x = torch.rand(100,2)
    >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
    >>> B_batch(x, grid, k=3).shape
    torch.Size([100, 2, 7])
    '''
    x = x.unsqueeze(dim=2)        # (batch, in_dim, 1)
    grid = grid.unsqueeze(dim=0)  # (1, in_dim, points)

    # Return numeric values for every order: a bare comparison would make the
    # k=0 result a bool tensor, unlike all higher orders.
    out_dtype = torch.promote_types(x.dtype, grid.dtype)
    if not out_dtype.is_floating_point:
        out_dtype = torch.get_default_dtype()

    # Order 0: indicator of the interval each sample falls into.
    value = ((x >= grid[:, :, :-1]) & (x < grid[:, :, 1:])).to(out_dtype)

    # Raise the order one step at a time; each step yields one basis fewer.
    for p in range(1, k + 1):
        left = (x - grid[:, :, :-(p + 1)]) / (grid[:, :, p:-1] - grid[:, :, :-(p + 1)])
        right = (grid[:, :, p + 1:] - x) / (grid[:, :, p + 1:] - grid[:, :, 1:(-p)])
        value = left * value[:, :, :-1] + right * value[:, :, 1:]
        # in case grid is degenerate (repeated knots give 0/0 or x/0 above)
        value = torch.nan_to_num(value)
    return value
def coef2curve(x_eval, grid, coef, k, device="cpu"):
    '''
    Evaluate B-spline curves at x_eval from their coefficients.

    The B-spline bases returned by B_batch are contracted with coef over the
    basis axis, giving one curve value per (sample, input, output) triple.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        grid : 2D torch.tensor
            shape (in_dim, number of grid points)
        coef : 3D torch.tensor
            shape (in_dim, out_dim, number of bases); moved to the device of
            the computed bases before the contraction.
        k : int
            the piecewise polynomial order of splines.
        device : str
            Not read by this function.

    Returns:
    --------
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
    '''
    basis = B_batch(x_eval, grid, k=k)
    weights = coef.to(basis.device)
    # i: sample, j: input dimension, k: basis index, l: output dimension
    return torch.einsum('ijk,jlk->ijl', basis, weights)
def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
    '''
    Convert B-spline curves to B-spline coefficients using regularized
    least squares.

    For every input dimension the normal equations
    (X^T X + lamb * I) coef = X^T y are solved with a pseudo-inverse, where X
    holds the B-spline bases evaluated at x_eval.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
        grid : 2D torch.tensor
            shape (in_dim, number of grid points)
        k : int
            spline order
        lamb : float
            regularized least square lambda

    Returns:
    --------
        coef : 3D torch.tensor
            shape (in_dim, out_dim, number of grid points - k - 1)
    '''
    # Design matrix per input dimension: (in_dim, batch, n_coef).
    basis = B_batch(x_eval, grid, k).permute(1, 0, 2)
    if not basis.is_floating_point():
        # Non-float bases (e.g. boolean indicators) cannot be used in matmul.
        basis = basis.to(y_eval.dtype)
    basis_t = basis.transpose(1, 2)                # (in_dim, n_coef, batch)

    # The Gram matrix does not depend on the output dimension, so it is built
    # and inverted once per input dimension and shared by all outputs.
    gram = basis_t @ basis                         # (in_dim, n_coef, n_coef)
    rhs = basis_t @ y_eval.permute(1, 0, 2)        # (in_dim, n_coef, out_dim)

    # Create the ridge term directly in gram's dtype and on its device.
    ridge = lamb * torch.eye(gram.shape[-1], dtype=gram.dtype, device=gram.device)
    coef = (gram + ridge).pinverse() @ rhs         # (in_dim, n_coef, out_dim)
    return coef.permute(0, 2, 1)
def extend_grid(grid, k_extend=0):
    '''
    Pad a grid with k_extend extra points on each side.

    The spacing of the added points is the mean spacing of the input grid,
    (last point - first point) / (number of points - 1), computed per row.

    Args:
    -----
        grid : 2D torch.tensor
            shape (number of rows, number of grid points)
        k_extend : int
            number of points to add at each end. Default: 0

    Returns:
    --------
        grid : 2D torch.tensor
            shape (number of rows, number of grid points + 2 * k_extend);
            the input itself when k_extend is 0.
    '''
    n_points = grid.shape[1]
    step = (grid[:, [-1]] - grid[:, [0]]) / (n_points - 1)

    extended = grid
    for _ in range(k_extend):
        # Grow by one point on the left and one on the right per round.
        extended = torch.cat(
            [extended[:, :1] - step, extended, extended[:, -1:] + step], dim=1
        )
    return extended