import torch # Required package for torchint
import torchint
# torch.float64 is double precision.
data_type = torch.float64
# Prefer the GPU, but fall back to the CPU so the example still runs on
# machines without CUDA.
# NOTE(review): assumes torchint accepts 'cpu' as a device -- confirm.
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
# Set the data type and device used by the torchint backend.
torchint.set_backend(data_type, device_type)
def function(x1, x2, x3, params):
    """Evaluate the example integrand at (x1, x2, x3).

    The integrand is::

        a1 * exp(-a2 * (x1**2 + x2**2 + x3**2))
            + a3 * sin(x1) * cos(x2) * exp(x3)

    Args:
        x1: First coordinate; must be accepted by ``torch.sin``.
        x2: Second coordinate; must be accepted by ``torch.cos``.
        x3: Third coordinate; must be accepted by ``torch.exp``.
        params: Indexable container whose items 0, 1 and 2 are the
            coefficients a1, a2 and a3.

    Returns:
        The integrand value, with whatever shape broadcasting of the
        inputs produces.
    """
    a1, a2, a3 = params[0], params[1], params[2]
    squared_radius = x1**2 + x2**2 + x3**2
    gaussian_term = a1 * torch.exp(-a2 * squared_radius)
    oscillating_term = a3 * torch.sin(x1) * torch.cos(x2) * torch.exp(x3)
    return gaussian_term + oscillating_term
# Parameter table: each coefficient is sampled at 10000 evenly spaced values,
# and the three 1-D tensors are stacked as columns, giving a 2-D tensor with
# one row per parameter set (1e4 parameter sets in this example).
a1_values, a2_values, a3_values = (
    torch.linspace(start, stop, 10000, dtype=data_type, device=device_type)
    for start, stop in ((1.0, 10.0), (2.0, 20.0), (0.5, 5))
)
param_values = torch.stack((a1_values, a2_values, a3_values), dim=1)
# Integration limits: (0, 1) for each of x1, x2 and x3.
bound = [[0, 1] for _ in range(3)]
# Number of sampling points per dimension.
num_point = [20] * 3
def boundary(x1, x2, x3):
    """Tell whether a point lies strictly inside the shell 0.2 < r^2 < 0.8.

    Here r^2 = x1**2 + x2**2 + x3**2. Both comparisons are strict, so
    points exactly on either surface are excluded.

    Args:
        x1: First coordinate (number or tensor).
        x2: Second coordinate (number or tensor).
        x3: Third coordinate (number or tensor).

    Returns:
        The element-wise ``&`` of the two comparisons: a boolean tensor
        for tensor inputs, a plain ``bool`` for Python numbers.
    """
    # Compute the squared radius once and reuse it for both comparisons.
    squared_radius = x1**2 + x2**2 + x3**2
    return (squared_radius > 0.2) & (squared_radius < 0.8)
# Integrate with torchint's trapz_integrate on the initial grid.
integral_value = torchint.trapz_integrate(
    function, param_values, bound, num_point, boundary
)
print(f"integral value: {integral_value}") # Output integral value
print(f"length of integral value: {integral_value.size()}") # Output length of the integral value
# To estimate the error, double the number of sampling points in every
# dimension, integrate again, and compare the two results.
num_point = [2 * n for n in num_point]
integral_value2 = torchint.trapz_integrate(
    function, param_values, bound, num_point, boundary
)
# Normalize by the magnitude of the denser-grid result so the relative error
# is non-negative regardless of the sign of the integral.
relative_error = (
    torch.abs(integral_value - integral_value2) / torch.abs(integral_value2)
)
print(f"integral value with denser grids: {integral_value2}")
print(f"relative error: {relative_error}")
print(integral_value.dtype)
print(integral_value.device)