Rethink the `diff` function
At the core of the neurodiffeq library is the `diff(x, t)` function, which computes the partial derivative ∂x/∂t evaluated at `t`. Usually, both tensors `t` and `x` have shape `(n_samples, 1)`. When either `x.shape` or `t.shape` is malformed, however, broadcasting can silently produce wrong results. Such cases are subtle enough that they went unnoticed for a long time.
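The broadcasting rule that PyTorch (like NumPy) applies to elementwise operations can be sketched in plain Python. Shapes are aligned from the trailing dimension, and each pair of dimensions must either be equal or contain a 1, which is stretched to match. `broadcast_shape` below is a hypothetical helper written for illustration, not part of neurodiffeq or PyTorch; it only computes the resulting shape.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Shape produced by an elementwise op between tensors of shapes a and b.

    Sketch of the NumPy/PyTorch broadcasting rule: walk both shapes from the
    last dimension backwards; a missing dimension counts as 1; each pair of
    dimensions must be equal, or one of them must be 1.
    """
    result = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
    return tuple(reversed(result))


n = 10
print(broadcast_shape((n, 1), (n, 1)))  # (10, 1): shapes agree, the op stays elementwise
print(broadcast_shape((n,), (n, 1)))    # (10, 10): silently becomes an n-by-n matrix
```

The second call is the dangerous one: combining an `(n,)` tensor with an `(n, 1)` tensor raises no error, it just produces an `(n, n)` result.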
All our generators (as defined in `neurodiffeq.generator`) currently return tensors with shape `(n_samples,)` instead of `(n_samples, 1)`. Effort should go into unifying tensor shapes everywhere.
Here are two simple cases for review.
Case 1: Shapes don't matter
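In this case, we try different combinations of `x.shape` and `t.shape` and check the shape of the output ∂x/∂t (each line reads `x.shape` and `t.shape` → shape of ∂x/∂t):

- `[n, 1]` and `[n]` → `[n]`
- `[n]` and `[n]` → `[n]`
- `[n, 1]` and `[n, 1]` → `[n, 1]`
- `[n]` and `[n, 1]` → `[n, 1]`

In every combination, the output takes the shape of `t`.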
To see this, run the following code. Note that `d1`, `d2`, `d3`, and `d4` have different shapes but hold the same values. This is why we wrongly believed the `diff()` function to be sound.
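```python
import torch
from neurodiffeq import diff

n = 10

# t has shape (n,)
t = torch.rand(n, requires_grad=True)
x = torch.sin(t)
d1 = diff(x.reshape(-1, 1), t)  # x: (n, 1), t: (n,)   -> (n,)
d2 = diff(x.reshape(-1), t)     # x: (n,),   t: (n,)   -> (n,)

# t has shape (n, 1)
t = t.reshape(-1, 1)
x = torch.sin(t)
d3 = diff(x.reshape(-1, 1), t)  # x: (n, 1), t: (n, 1) -> (n, 1)
d4 = diff(x.reshape(-1), t)     # x: (n,),   t: (n, 1) -> (n, 1)
```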
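One way to see why the output always follows `t` is to sketch the derivative on top of `torch.autograd.grad`. The sketch below assumes, as a simplification, that `diff(x, t)` amounts to a single `autograd.grad` call with a tensor of ones as `grad_outputs`; `naive_diff` is a hypothetical name, not neurodiffeq API. `autograd.grad` returns a gradient shaped like the tensor it differentiates with respect to, and because each `x[i]` depends only on `t[i]`, the gradient of `x.sum()` equals the per-sample derivative. The shape of `x` is effectively ignored, which is why Case 1 looks harmless.

```python
import torch


def naive_diff(x, t):
    # hypothetical stand-in for diff(): the gradient of x.sum() with respect to t,
    # obtained by passing a tensor of ones as grad_outputs
    grad, = torch.autograd.grad(x, t, grad_outputs=torch.ones_like(x), create_graph=True)
    return grad


n = 10

t_col = torch.rand(n, 1, requires_grad=True)  # t: (n, 1)
x_col = torch.sin(t_col)                      # x: (n, 1)
d_a = naive_diff(x_col, t_col)                # result: (n, 1)
d_b = naive_diff(x_col.reshape(-1), t_col)    # x reshaped to (n,), result still (n, 1)

t_flat = torch.rand(n, requires_grad=True)    # t: (n,)
d_c = naive_diff(torch.sin(t_flat).reshape(-1, 1), t_flat)  # x: (n, 1), result: (n,)

print(d_a.shape, d_b.shape, d_c.shape)  # torch.Size([10, 1]) torch.Size([10, 1]) torch.Size([10])
```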
Case 2: Shapes matter
In this second case, we examine two new operators, `div` and `curl` in spherical coordinates, and show that whether the vector identity `div(curl(...)) == 0` holds depends on the shapes of the dependent and independent variables.
Here are the definitions of curl and divergence in spherical coordinates:
```python
import torch
from torch import sin
from neurodiffeq import diff

# these two operators have been recently implemented in neurodiffeq.operators
def spherical_curl(u_r, u_theta, u_phi, r, theta, phi):
    # shorthands for the partial derivatives with respect to each coordinate
    d_r = lambda u: diff(u, r)
    d_theta = lambda u: diff(u, theta)
    d_phi = lambda u: diff(u, phi)
    curl_r = (d_theta(u_phi * sin(theta)) - d_phi(u_theta)) / (r * sin(theta))
    curl_theta = (d_phi(u_r) / sin(theta) - d_r(u_phi * r)) / r
    curl_phi = (d_r(u_theta * r) - d_theta(u_r)) / r
    return curl_r, curl_theta, curl_phi


def spherical_div(u_r, u_theta, u_phi, r, theta, phi):
    div_r = diff(u_r * r ** 2, r) / r ** 2
    div_theta = diff(u_theta * sin(theta), theta) / (r * sin(theta))
    div_phi = diff(u_phi, phi) / (r * sin(theta))
    return div_r + div_theta + div_phi
```
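For reference, `spherical_curl` and `spherical_div` implement the standard spherical-coordinate formulas below, with $\mathbf{u} = (u_r, u_\theta, u_\phi)$ matching the function arguments. Multiplying the divergence of the curl by $r^2\sin\theta$ shows that every term cancels against another once mixed partial derivatives commute, so `div(curl(q))` should vanish up to floating-point error, provided each `diff` call returns a genuine per-sample partial derivative.

```latex
\begin{aligned}
(\nabla\times\mathbf{u})_r &= \frac{1}{r\sin\theta}\left[\frac{\partial (u_\phi\sin\theta)}{\partial\theta} - \frac{\partial u_\theta}{\partial\phi}\right] \\
(\nabla\times\mathbf{u})_\theta &= \frac{1}{r}\left[\frac{1}{\sin\theta}\frac{\partial u_r}{\partial\phi} - \frac{\partial (r\,u_\phi)}{\partial r}\right] \\
(\nabla\times\mathbf{u})_\phi &= \frac{1}{r}\left[\frac{\partial (r\,u_\theta)}{\partial r} - \frac{\partial u_r}{\partial\theta}\right] \\
\nabla\cdot\mathbf{u} &= \frac{1}{r^2}\frac{\partial (r^2 u_r)}{\partial r} + \frac{1}{r\sin\theta}\frac{\partial (u_\theta\sin\theta)}{\partial\theta} + \frac{1}{r\sin\theta}\frac{\partial u_\phi}{\partial\phi} \\
r^2\sin\theta\,\nabla\cdot(\nabla\times\mathbf{u}) &= \partial_r\partial_\theta(r\sin\theta\,u_\phi) - \partial_r\partial_\phi(r\,u_\theta) + \partial_\theta\partial_\phi u_r \\
&\quad - \partial_\theta\partial_r(r\sin\theta\,u_\phi) + \partial_\phi\partial_r(r\,u_\theta) - \partial_\phi\partial_\theta u_r = 0
\end{aligned}
```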
Here we define a vector field `q` by specifying the rule to compute `q` given coordinates `(r, theta, phi)`:
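```python
import torch

def compute_q(r, theta, phi):
    # stack the coordinates into an (n, 3) matrix, one row per sample
    r_theta_phi = torch.stack([r.flatten(), theta.flatten(), phi.flatten()], dim=1)
    # a fixed linear map followed by tanh, applied to each sample independently
    W = torch.tensor([
        [.01, .04, .07],
        [.02, .05, .08],
        [.03, .06, .09],
    ])
    q = torch.matmul(r_theta_phi, W)
    q = torch.tanh(q)
    # each returned component has shape (n,), whatever the input shapes were
    return q[:, 0], q[:, 1], q[:, 2]
```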
We then test the vector identity `div(curl(q)) == 0` for `q`:
```python
import numpy as np
import torch

n = 10

# create r, theta, and phi with shape (n, 1)
r = torch.rand(n, 1, requires_grad=True) + 0.1
theta = torch.rand(n, 1, requires_grad=True) * np.pi
phi = torch.rand(n, 1, requires_grad=True) * np.pi * 2
q_r, q_theta, q_phi = compute_q(r, theta, phi)

# bind the operators to the r, theta, phi created above
div = lambda u_r, u_theta, u_phi: spherical_div(u_r, u_theta, u_phi, r, theta, phi)
curl = lambda u_r, u_theta, u_phi: spherical_curl(u_r, u_theta, u_phi, r, theta, phi)
# q: (n, 1), coordinates: (n, 1)
div_curl_q1 = div(*curl(q_r.reshape(-1, 1), q_theta.reshape(-1, 1), q_phi.reshape(-1, 1)))
# q: (n,), coordinates: (n, 1)
div_curl_q2 = div(*curl(q_r.reshape(-1), q_theta.reshape(-1), q_phi.reshape(-1)))

# create r, theta, and phi with shape (n,)
r = r.reshape(-1)
theta = theta.reshape(-1)
phi = phi.reshape(-1)
q_r, q_theta, q_phi = compute_q(r, theta, phi)

# bind the operators to the r, theta, phi created above
div = lambda u_r, u_theta, u_phi: spherical_div(u_r, u_theta, u_phi, r, theta, phi)
curl = lambda u_r, u_theta, u_phi: spherical_curl(u_r, u_theta, u_phi, r, theta, phi)
# q: (n, 1), coordinates: (n,)
div_curl_q3 = div(*curl(q_r.reshape(-1, 1), q_theta.reshape(-1, 1), q_phi.reshape(-1, 1)))
# q: (n,), coordinates: (n,)
div_curl_q4 = div(*curl(q_r.reshape(-1), q_theta.reshape(-1), q_phi.reshape(-1)))

print(div_curl_q1, div_curl_q2, div_curl_q3, div_curl_q4, sep="\n")
```

Printing all four `div_curl_q` tensors should show that `div_curl_q1` and `div_curl_q4`, where the dependent and independent variables share the same shape, are (approximately) 0, while the mixed-shape cases `div_curl_q2` and `div_curl_q3` are not. In the mixed cases, broadcasting turns products such as `u_phi * sin(theta)` into `(n, n)` matrices, so the derivatives mix contributions from different samples and come out wrong without any error being raised. Since `(n_samples, 1)` is the usual shape for both tensors, the safe convention is to give both the dependent and independent variables shape `(n, 1)` everywhere for the differentiation to go correctly.
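To see where the mixed-shape cases go wrong, consider the first term of `curl_r`, which multiplies a component of `q` by `sin(theta)`. The sketch below uses plain PyTorch; `naive_diff` and `require_column` are hypothetical helpers, not neurodiffeq API, and `naive_diff` assumes the derivative is a single `autograd.grad` call with ones as `grad_outputs`. It shows that an `(n,)` component times an `(n, 1)` coordinate broadcasts to an `(n, n)` matrix, so the derivative with respect to `theta` sums over every sample instead of being a per-sample partial derivative. A shape guard such as `require_column` turns this silent error into an exception.

```python
import torch

torch.manual_seed(0)


def naive_diff(u, t):
    # hypothetical stand-in for diff(): the gradient of u.sum() with respect to t
    grad, = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)
    return grad


def require_column(*tensors):
    # hypothetical guard: reject any tensor whose shape is not (n_samples, 1)
    for tensor in tensors:
        if tensor.dim() != 2 or tensor.shape[1] != 1:
            raise ValueError(f"expected shape (n_samples, 1), got {tuple(tensor.shape)}")


n = 10
theta = torch.rand(n, 1, requires_grad=True)
u_phi = torch.tanh(theta)                   # a per-sample function of theta, shape (n, 1)

good = u_phi * torch.sin(theta)             # (n, 1) * (n, 1) -> (n, 1)
bad = u_phi.reshape(-1) * torch.sin(theta)  # (n,)   * (n, 1) -> (n, n)
print(good.shape, bad.shape)                # torch.Size([10, 1]) torch.Size([10, 10])

d_good = naive_diff(good, theta)  # per-sample derivative of u_phi * sin(theta)
d_bad = naive_diff(bad, theta)    # also shape (n, 1), but each entry sums over all samples
print(torch.allclose(d_good, d_bad))

try:
    require_column(u_phi.reshape(-1), theta)
except ValueError as err:
    print(err)  # expected shape (n_samples, 1), got (10,)
```

Both derivatives come back with shape `(n, 1)`, so nothing downstream complains; only the values differ. Checking shapes up front, as `require_column` does, is one way to make such mistakes fail loudly.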