diff --git a/gpu4pyscf/lib/multigrid/multigrid_v2/screen.cu b/gpu4pyscf/lib/multigrid/multigrid_v2/screen.cu index b6574d95d..39ebbcc39 100644 --- a/gpu4pyscf/lib/multigrid/multigrid_v2/screen.cu +++ b/gpu4pyscf/lib/multigrid/multigrid_v2/screen.cu @@ -206,4 +206,34 @@ void put_pairs_on_blocks( checkCudaErrors(cudaPeekAtLastError()); } + +int tailor_gaussian_pairs( + int *sorted_pairs_per_local_grid, int *n_pairs_per_local_grid, + const int i_angular, const int j_angular, const int *non_trivial_pairs, + const int *i_shells, const int *j_shells, const int n_j_shells, + const int *shell_to_ao_indices, + const int *accumulated_n_pairs_per_local_grid, + const int *sorted_block_index, const int n_contributing_blocks, + const int *image_indices, const double *vectors_to_neighboring_images, + const int n_images, const int *mesh, const int *atm, const int *bas, + const double *env, const int is_non_orthogonal, const double threshold, + const int derivative_order) { + if (is_non_orthogonal) { + return gpu4pyscf::gpbc::multi_grid::tailor_gaussian_pairs_driver( + sorted_pairs_per_local_grid, n_pairs_per_local_grid, i_angular, + j_angular, non_trivial_pairs, i_shells, j_shells, n_j_shells, + shell_to_ao_indices, accumulated_n_pairs_per_local_grid, + sorted_block_index, n_contributing_blocks, image_indices, + vectors_to_neighboring_images, n_images, mesh, atm, bas, env, threshold, + derivative_order); + } else { + return gpu4pyscf::gpbc::multi_grid::tailor_gaussian_pairs_driver( + sorted_pairs_per_local_grid, n_pairs_per_local_grid, i_angular, + j_angular, non_trivial_pairs, i_shells, j_shells, n_j_shells, + shell_to_ao_indices, accumulated_n_pairs_per_local_grid, + sorted_block_index, n_contributing_blocks, image_indices, + vectors_to_neighboring_images, n_images, mesh, atm, bas, env, threshold, + derivative_order); + } +} } diff --git a/gpu4pyscf/lib/multigrid/multigrid_v2/screening.cuh b/gpu4pyscf/lib/multigrid/multigrid_v2/screening.cuh index b9adb4ea6..7abdb14a5 100644 --- a/gpu4pyscf/lib/multigrid/multigrid_v2/screening.cuh +++ b/gpu4pyscf/lib/multigrid/multigrid_v2/screening.cuh @@ -26,7 +26,7 @@ #define EIJ_CUTOFF 60 #define BLOCK_DIM_XYZ 4 -#define EXP_OVERFLOW 400 +#define EXP_OVERFLOW 400 namespace gpu4pyscf::gpbc::multi_grid { @@ -644,4 +644,318 @@ __global__ void put_pairs_on_blocks_kernel( } } +template +__global__ static void tailor_gaussian_pairs_kernel( + int *sorted_pairs_per_local_grid, int *n_pairs_per_local_grid, + const int *non_trivial_pairs, const int *i_shells, const int *j_shells, + const int n_j_shells, const int *shell_to_ao_indices, + const int *accumulated_n_pairs_per_local_grid, + const int *sorted_block_index, const int *image_indices, + const double *vectors_to_neighboring_images, const int n_images, + const int mesh_a, const int mesh_b, const int mesh_c, const int *atm, + const int *bas, const double *env, const double threshold, + const int derivative_order) { + constexpr int n_threads = BLOCK_DIM_XYZ * BLOCK_DIM_XYZ * BLOCK_DIM_XYZ; + + const int block_index = sorted_block_index[blockIdx.x]; + const int n_blocks_b = (mesh_b + BLOCK_DIM_XYZ - 1) / BLOCK_DIM_XYZ; + const int n_blocks_c = (mesh_c + BLOCK_DIM_XYZ - 1) / BLOCK_DIM_XYZ; + + const int block_a_stride = n_blocks_b * n_blocks_c; + const int block_a_index = block_index / block_a_stride; + const int block_ab_index = block_index % block_a_stride; + const int block_b_index = block_ab_index / n_blocks_c; + const int block_c_index = block_ab_index % n_blocks_c; + + const int a_start = block_a_index * BLOCK_DIM_XYZ; + const int b_start = block_b_index * BLOCK_DIM_XYZ; + const int c_start = block_c_index * BLOCK_DIM_XYZ; + + const double start_position_x = + dxyz_dabc[0] * a_start + dxyz_dabc[3] * b_start + dxyz_dabc[6] * c_start; + const double start_position_y = + dxyz_dabc[1] * a_start + dxyz_dabc[4] * b_start + dxyz_dabc[7] * c_start; + const double start_position_z = + dxyz_dabc[2] * a_start + dxyz_dabc[5] * b_start + dxyz_dabc[8] * c_start; + + const double a_dot_b = dxyz_dabc[0] * dxyz_dabc[3] + + dxyz_dabc[1] * dxyz_dabc[4] + + dxyz_dabc[2] * dxyz_dabc[5]; + const double a_dot_c = dxyz_dabc[0] * dxyz_dabc[6] + + dxyz_dabc[1] * dxyz_dabc[7] + + dxyz_dabc[2] * dxyz_dabc[8]; + const double b_dot_c = dxyz_dabc[3] * dxyz_dabc[6] + + dxyz_dabc[4] * dxyz_dabc[7] + + dxyz_dabc[5] * dxyz_dabc[8]; + + const int a_upper = min(a_start + BLOCK_DIM_XYZ, mesh_a) - a_start; + const int b_upper = min(b_start + BLOCK_DIM_XYZ, mesh_b) - b_start; + const int c_upper = min(c_start + BLOCK_DIM_XYZ, mesh_c) - c_start; + + const int thread_id = threadIdx.x + threadIdx.y * BLOCK_DIM_XYZ + + threadIdx.z * BLOCK_DIM_XYZ * BLOCK_DIM_XYZ; + + const int start_pair_index = accumulated_n_pairs_per_local_grid[block_index]; + const int end_pair_index = + accumulated_n_pairs_per_local_grid[block_index + 1]; + const int n_pairs = end_pair_index - start_pair_index; + const int n_batches = (n_pairs + n_threads - 1) / n_threads; + + for (int i_batch = 0, i_pair_index = start_pair_index + thread_id; + i_batch < n_batches; i_batch++, i_pair_index += n_threads) { + const bool is_valid_pair = i_pair_index < end_pair_index; + const int i_pair = + is_valid_pair ? sorted_pairs_per_local_grid[i_pair_index] : 0; + + const int image_index = image_indices[i_pair]; + const int image_index_i = image_index / n_images; + const int image_index_j = image_index % n_images; + + const int shell_pair_index = non_trivial_pairs[i_pair]; + const int i_shell_index = shell_pair_index / n_j_shells; + const int j_shell_index = shell_pair_index % n_j_shells; + const int i_shell = i_shells[i_shell_index]; + const int j_shell = j_shells[j_shell_index]; + + const double i_exponent = env[bas(PTR_EXP, i_shell)]; + const int i_coord_offset = atm(PTR_COORD, bas(ATOM_OF, i_shell)); + const double i_x = + env[i_coord_offset] + vectors_to_neighboring_images[image_index_i * 3]; + const double i_y = env[i_coord_offset + 1] + + vectors_to_neighboring_images[image_index_i * 3 + 1]; + const double i_z = env[i_coord_offset + 2] + + vectors_to_neighboring_images[image_index_i * 3 + 2]; + const double i_coeff = env[bas(PTR_COEFF, i_shell)]; + + const double j_exponent = env[bas(PTR_EXP, j_shell)]; + const int j_coord_offset = atm(PTR_COORD, bas(ATOM_OF, j_shell)); + const double j_x = + env[j_coord_offset] + vectors_to_neighboring_images[image_index_j * 3]; + const double j_y = env[j_coord_offset + 1] + + vectors_to_neighboring_images[image_index_j * 3 + 1]; + const double j_z = env[j_coord_offset + 2] + + vectors_to_neighboring_images[image_index_j * 3 + 2]; + const double j_coeff = env[bas(PTR_COEFF, j_shell)]; + + const double ij_exponent = i_exponent + j_exponent; + const double ij_exponent_in_prefactor = + i_exponent * j_exponent / ij_exponent * + distance_squared(i_x - j_x, i_y - j_y, i_z - j_z); + + const double pair_x = (i_exponent * i_x + j_exponent * j_x) / ij_exponent; + const double pair_y = (i_exponent * i_y + j_exponent * j_y) / ij_exponent; + const double pair_z = (i_exponent * i_z + j_exponent * j_z) / ij_exponent; + + const double x0 = start_position_x - pair_x; + const double y0 = start_position_y - pair_y; + const double z0 = start_position_z - pair_z; + + const double gaussian_exponent_at_reference = + ij_exponent * distance_squared(x0, y0, z0); + + const double pair_prefactor = i_coeff * j_coeff * + common_fac_sp() * + common_fac_sp(); + + const double gaussian_starting_point = + is_valid_pair + ? exp(-(ij_exponent_in_prefactor + gaussian_exponent_at_reference) / + 3.0) + : 0; + + const double da_squared = + distance_squared(dxyz_dabc[0], dxyz_dabc[1], dxyz_dabc[2]); + const double db_squared = + distance_squared(dxyz_dabc[3], dxyz_dabc[4], dxyz_dabc[5]); + const double dc_squared = + distance_squared(dxyz_dabc[6], dxyz_dabc[7], dxyz_dabc[8]); + + const double exp_da_squared = exp(-2 * ij_exponent * da_squared); + const double exp_db_squared = exp(-2 * ij_exponent * db_squared); + const double exp_dc_squared = exp(-2 * ij_exponent * dc_squared); + + const double cross_term_a = + dxyz_dabc[0] * x0 + dxyz_dabc[1] * y0 + dxyz_dabc[2] * z0; + const double cross_term_b = + dxyz_dabc[3] * x0 + dxyz_dabc[4] * y0 + dxyz_dabc[5] * z0; + const double cross_term_c = + dxyz_dabc[6] * x0 + dxyz_dabc[7] * y0 + dxyz_dabc[8] * z0; + + const double recursion_factor_a_start = + exp(-ij_exponent * (2 * cross_term_a + da_squared)); + const double recursion_factor_b_start = + exp(-ij_exponent * (2 * cross_term_b + db_squared)); + const double recursion_factor_c_start = + exp(-ij_exponent * (2 * cross_term_c + dc_squared)); + + const double exp_dadb = exp(-2 * ij_exponent * a_dot_b); + const double exp_dadc = exp(-2 * ij_exponent * a_dot_c); + const double exp_dbdc = exp(-2 * ij_exponent * b_dot_c); + + int a_index, b_index, c_index; + double x, y, z; + double gaussian_x, gaussian_y, gaussian_z, recursion_factor_a, + recursion_factor_b, recursion_factor_c; + double recursion_factor_ab_pow_a = 1; + double recursion_factor_ac_pow_a = 1; + double recursion_factor_bc_pow_b = 1; + + if constexpr (is_non_orthogonal) { + // recursion_factor_ab_pow_a = 1; + // recursion_factor_ac_pow_a = 1; + } else { + x = start_position_x; + } + + double max_gaussian_value = 0; + + for (a_index = 0, gaussian_x = gaussian_starting_point, + recursion_factor_a = recursion_factor_a_start; + a_index < a_upper; a_index++, gaussian_x *= recursion_factor_a, + recursion_factor_a *= exp_da_squared) { + + if constexpr (is_non_orthogonal) { + recursion_factor_bc_pow_b = 1; + } else { + y = start_position_y; + } + for (b_index = 0, gaussian_y = gaussian_starting_point, + recursion_factor_b = recursion_factor_b_start; + b_index < b_upper; b_index++, + gaussian_y *= recursion_factor_b * recursion_factor_ab_pow_a, + recursion_factor_b *= exp_db_squared) { + + if constexpr (is_non_orthogonal) { + x = start_position_x + a_index * dxyz_dabc[0] + + b_index * dxyz_dabc[3]; + y = start_position_y + a_index * dxyz_dabc[1] + + b_index * dxyz_dabc[4]; + z = start_position_z + a_index * dxyz_dabc[2] + + b_index * dxyz_dabc[5]; + } else { + z = start_position_z; + } + for (c_index = 0, gaussian_z = gaussian_starting_point, + recursion_factor_c = recursion_factor_c_start; + c_index < c_upper; c_index++, + gaussian_z *= recursion_factor_c * recursion_factor_ac_pow_a * + recursion_factor_bc_pow_b, + recursion_factor_c *= exp_dc_squared) { + + const double r_i = sqrt(distance_squared(x - i_x, y - i_y, z - i_z)); + const double r_j = sqrt(distance_squared(x - j_x, y - j_y, z - j_z)); + const double r_p = + sqrt(distance_squared(x - pair_x, y - pair_y, z - pair_z)); + + const double approxmate_polynomial = + approximate_polynomial_value( + r_i, r_j, r_p, derivative_order); + + const double gaussian = gaussian_x * gaussian_y * gaussian_z; + + const double approximate_value = + abs(4.0 * M_PI * r_p * r_p * pair_prefactor * + approxmate_polynomial * gaussian); + + max_gaussian_value = max(max_gaussian_value, approximate_value); + + if constexpr (is_non_orthogonal) { + x += dxyz_dabc[6]; + y += dxyz_dabc[7]; + z += dxyz_dabc[8]; + } else { + z += dxyz_dabc[8]; + } + } + + if constexpr (is_non_orthogonal) { + recursion_factor_bc_pow_b *= exp_dbdc; + } else { + y += dxyz_dabc[4]; + } + } + + if constexpr (is_non_orthogonal) { + recursion_factor_ab_pow_a *= exp_dadb; + recursion_factor_ac_pow_a *= exp_dadc; + } else { + x += dxyz_dabc[0]; + } + } + + if (max_gaussian_value < threshold && is_valid_pair) { + sorted_pairs_per_local_grid[i_pair_index] = -1; + atomicAdd(n_pairs_per_local_grid + block_index, -1); + } + } +} + +#define tailor_gaussian_pairs_kernel_macro(li, lj) \ + tailor_gaussian_pairs_kernel \ + <<>>( \ + sorted_pairs_per_local_grid, n_pairs_per_local_grid, \ + non_trivial_pairs, i_shells, j_shells, n_j_shells, \ + shell_to_ao_indices, accumulated_n_pairs_per_local_grid, \ + sorted_block_index, image_indices, vectors_to_neighboring_images, \ + n_images, mesh_a, mesh_b, mesh_c, atm, bas, env, threshold, \ + derivative_order); + +#define tailor_gaussian_pairs_kernel_case_macro(li, lj) \ + case (li * 10 + lj): \ + tailor_gaussian_pairs_kernel_macro(li, lj); \ + break + +template +int tailor_gaussian_pairs_driver( + int *sorted_pairs_per_local_grid, int *n_pairs_per_local_grid, + const int i_angular, const int j_angular, const int *non_trivial_pairs, + const int *i_shells, const int *j_shells, const int n_j_shells, + const int *shell_to_ao_indices, + const int *accumulated_n_pairs_per_local_grid, + const int *sorted_block_index, const int n_contributing_blocks, + const int *image_indices, const double *vectors_to_neighboring_images, + const int n_images, const int *mesh, const int *atm, const int *bas, + const double *env, const double threshold, const int derivative_order) { + dim3 block_size(BLOCK_DIM_XYZ, BLOCK_DIM_XYZ, BLOCK_DIM_XYZ); + int mesh_a = mesh[0]; + int mesh_b = mesh[1]; + int mesh_c = mesh[2]; + dim3 block_grid(n_contributing_blocks, 1, 1); + switch (i_angular * 10 + j_angular) { + tailor_gaussian_pairs_kernel_case_macro(0, 0); + tailor_gaussian_pairs_kernel_case_macro(0, 1); + tailor_gaussian_pairs_kernel_case_macro(0, 2); + tailor_gaussian_pairs_kernel_case_macro(0, 3); + tailor_gaussian_pairs_kernel_case_macro(0, 4); + tailor_gaussian_pairs_kernel_case_macro(1, 0); + tailor_gaussian_pairs_kernel_case_macro(1, 1); + tailor_gaussian_pairs_kernel_case_macro(1, 2); + tailor_gaussian_pairs_kernel_case_macro(1, 3); + tailor_gaussian_pairs_kernel_case_macro(1, 4); + tailor_gaussian_pairs_kernel_case_macro(2, 0); + tailor_gaussian_pairs_kernel_case_macro(2, 1); + tailor_gaussian_pairs_kernel_case_macro(2, 2); + tailor_gaussian_pairs_kernel_case_macro(2, 3); + tailor_gaussian_pairs_kernel_case_macro(2, 4); + tailor_gaussian_pairs_kernel_case_macro(3, 0); + tailor_gaussian_pairs_kernel_case_macro(3, 1); + tailor_gaussian_pairs_kernel_case_macro(3, 2); + tailor_gaussian_pairs_kernel_case_macro(3, 3); + tailor_gaussian_pairs_kernel_case_macro(3, 4); + tailor_gaussian_pairs_kernel_case_macro(4, 0); + tailor_gaussian_pairs_kernel_case_macro(4, 1); + tailor_gaussian_pairs_kernel_case_macro(4, 2); + tailor_gaussian_pairs_kernel_case_macro(4, 3); + tailor_gaussian_pairs_kernel_case_macro(4, 4); + default: + fprintf(stderr, + "angular momentum pair %d, %d is not supported in " + "evaluate_density_driver\n", + i_angular, j_angular); + return 1; + } + + return checkCudaErrors(cudaPeekAtLastError()); +} + } // namespace gpu4pyscf::gpbc::multi_grid diff --git a/gpu4pyscf/lib/multigrid/multigrid_v2/utils.cuh b/gpu4pyscf/lib/multigrid/multigrid_v2/utils.cuh index cd774255d..e7b063117 100644 --- a/gpu4pyscf/lib/multigrid/multigrid_v2/utils.cuh +++ b/gpu4pyscf/lib/multigrid/multigrid_v2/utils.cuh @@ -26,6 +26,36 @@ __host__ __device__ T distance_squared(const T x, const T y, const T z) { return x * x + y * y + z * z; } +template +__host__ __device__ T approximate_polynomial_value(const double r_i, + const double r_j, + const double r_p, + const int derivative_order) { + + T result = pow(r_i, i_angular) * pow(r_j, j_angular); + + if (derivative_order > 0) { + result *= 2.0 * r_p; + if constexpr (i_angular > 0) { + result += i_angular * pow(r_i, i_angular - 1) * pow(r_j, j_angular); + } + + if constexpr (j_angular > 0) { + result += j_angular * pow(r_i, i_angular) * pow(r_j, j_angular - 1); + } + } + + if (derivative_order > 1) { + result *= 2.0 * r_p; + if constexpr (i_angular > 0 && j_angular > 0) { + result += i_angular * j_angular * pow(r_i, i_angular - 1) * + pow(r_j, j_angular - 1); + } + } + + return result; +} + template __device__ constexpr T common_fac_sp() { if constexpr (ANG == 0) { return 0.282094791773878143; diff --git a/gpu4pyscf/pbc/dft/multigrid_v2.py b/gpu4pyscf/pbc/dft/multigrid_v2.py index 14b04c27f..00b73e635 100644 --- a/gpu4pyscf/pbc/dft/multigrid_v2.py +++ b/gpu4pyscf/pbc/dft/multigrid_v2.py @@ -72,6 +72,137 @@ def ifft_in_place(x): return fft.ifftn(x, axes=(-3, -2, -1), overwrite_x=True) +def iG_density(density, cell): + mesh = cell.mesh + b = cp.asarray(cell.reciprocal_vectors()) + + mesh_in_int32 = np.array(mesh, dtype=np.int32) + block_dim = (8, 8, 8) + grid_dim = tuple(np.array(np.ceil(mesh_in_int32 / np.array(block_dim)), dtype=np.int32)) + + result = cp.zeros((3, *mesh), dtype=cp.complex128).reshape(3, -1) + + custom_kernel = cp.RawKernel( + r""" + #include + extern "C" __global__ + void kernel(complex * result, + const complex * density, + const double * reciprocal_lattice, + const int mesh_a, + const int mesh_b, + const int mesh_c) { + int a = blockIdx.x * blockDim.x + threadIdx.x; + int b = blockIdx.y * blockDim.y + threadIdx.y; + int c = blockIdx.z * blockDim.z + threadIdx.z; + + if(a >= mesh_a || b >= mesh_b || c >= mesh_c) { + return; + } + + const size_t grid_index = a * mesh_b * mesh_c + b * mesh_c + c; + result += grid_index; + const complex density_value = density[grid_index] * complex{0.0, 1.0}; + + if(a >= (mesh_a + 1) / 2) a -= mesh_a; + if(b >= (mesh_b + 1) / 2) b -= mesh_b; + if(c >= (mesh_c + 1) / 2) c -= mesh_c; + + int n_grid = mesh_a * mesh_b * mesh_c; + + *result = (a * reciprocal_lattice[0] + + b * reciprocal_lattice[3] + + c * reciprocal_lattice[6]) * density_value; + result += n_grid; + *result = (a * reciprocal_lattice[1] + + b * reciprocal_lattice[4] + + c * reciprocal_lattice[7]) * density_value; + result += n_grid; + *result = (a * reciprocal_lattice[2] + + b * reciprocal_lattice[5] + + c * reciprocal_lattice[8]) * density_value; + } + """, + 'kernel', + ) + + custom_kernel( + grid_dim, + block_dim, + (result, density, b, mesh_in_int32[0], mesh_in_int32[1], mesh_in_int32[2]), + ) + + return result + + +def contract_iG_potential(gga_potential, cell): + mesh = cell.mesh + b = cp.asarray(cell.reciprocal_vectors()) + + mesh_in_int32 = np.array(mesh, dtype=np.int32) + block_dim = (8, 8, 8) + grid_dim = tuple(np.array(np.ceil(mesh_in_int32 / np.array(block_dim)), dtype=np.int32)) + + custom_kernel = cp.RawKernel( + r""" + #include + extern "C" __global__ + void kernel(complex * gga_potential, + const double * reciprocal_lattice, + const int mesh_a, + const int mesh_b, + const int mesh_c) { + int a = blockIdx.x * blockDim.x + threadIdx.x; + int b = blockIdx.y * blockDim.y + threadIdx.y; + int c = blockIdx.z * blockDim.z + threadIdx.z; + + if(a >= mesh_a || b >= mesh_b || c >= mesh_c) { + return; + } + + int n_grid = mesh_a * mesh_b * mesh_c; + + const size_t grid_index = a * mesh_b * mesh_c + b * mesh_c + c; + gga_potential += grid_index; + + if(a >= (mesh_a + 1) / 2) a -= mesh_a; + if(b >= (mesh_b + 1) / 2) b -= mesh_b; + if(c >= (mesh_c + 1) / 2) c -= mesh_c; + complex potential_change = 0; + potential_change += + (a * reciprocal_lattice[0] + + b * reciprocal_lattice[3] + + c * reciprocal_lattice[6]) * + complex{0.0, 1.0} * + gga_potential[n_grid]; + + potential_change += + (a * reciprocal_lattice[1] + + b * reciprocal_lattice[4] + + c * reciprocal_lattice[7]) * + complex{0.0, 1.0} * + gga_potential[2 * n_grid]; + + potential_change += + (a * reciprocal_lattice[2] + + b * reciprocal_lattice[5] + + c * reciprocal_lattice[8]) * + complex{0.0, 1.0} * + gga_potential[3 * n_grid]; + + *gga_potential -= potential_change; + } + """, + 'kernel', + ) + + custom_kernel( + grid_dim, + block_dim, + (gga_potential, b, mesh_in_int32[0], mesh_in_int32[1], mesh_in_int32[2]), + ) + + def unique_with_sort(x): # This function does the same thing as cp.unique(x, return_inverse=True). # It's not super optimized, but for whatever reason, cp.unique is very slow, so this one is better. @@ -95,6 +226,131 @@ def unique_with_sort(x): return x, inverse_unique[inverse_sort] +def unique_with_multiple_keys(x): + # This function expands the previous function to handle multiple keys + # shaped as [ (1, 2), (3, -4), ....] + assert type(x) is cp.ndarray and (x.dtype == cp.int32 or x.dtype == cp.int64) and x.ndim == 2 + x = x.T + n = x.shape[-1] + + inverse_sort = cp.zeros(n, dtype=cp.int64) + if n <= 1: + return x, inverse_sort + + sort_index = cp.lexsort(x) + inverse_sort[sort_index] = cp.arange(0, n, dtype=cp.int64) + x = x[:, sort_index].T + + mask = cp.empty(n, dtype=cp.bool_) + mask[0] = True + mask[1:] = cp.any(x[1:] != x[:-1], axis=-1) + + x = x[mask] + inverse_unique = cp.cumsum(mask, dtype=cp.int64) - 1 + + return x, inverse_unique[inverse_sort] + + +def sort_contraction_coefficients(coeff): + contraction_shapes = cp.array([i.shape for i in coeff]) + unique_shapes, inverse = unique_with_multiple_keys(contraction_shapes) + unique_shapes = unique_shapes.get() + inverse = inverse.get() + + sliced_axis = np.zeros((len(coeff) + 1, 2), dtype=np.int32) + sliced_axis[1:] = np.cumsum(np.array([i.shape for i in coeff]), axis=0, dtype=np.int32) + + left_basis_function_indices = [cp.arange(begin, end) for begin, end in zip(sliced_axis[:-1, 0], sliced_axis[1:, 0])] + right_basis_function_indices = [ + cp.arange(begin, end) for begin, end in zip(sliced_axis[:-1, 1], sliced_axis[1:, 1]) + ] + + sorted_coeffs = [{'shape': shape, 'coeffs': [], 'left_indices': [], 'right_indices': []} for shape in unique_shapes] + + for category, coeffs, left, right in zip(inverse, coeff, left_basis_function_indices, right_basis_function_indices): + sorted_coeffs[category]['coeffs'].append(coeffs) + sorted_coeffs[category]['left_indices'].append(left) + sorted_coeffs[category]['right_indices'].append(right) + + for category in sorted_coeffs: + category['coeffs'] = cp.array(category['coeffs']) + category['left_indices'] = cp.concatenate(category['left_indices']) + category['right_indices'] = cp.concatenate(category['right_indices']) + + return sorted_coeffs, sliced_axis[-1] + + +def contracted_to_primitive(batched_matrices, sorted_coeffs_left, sorted_coeffs_right, primitive_shape): + assert len(batched_matrices.shape) == 3 + + n_slices = batched_matrices.shape[0] + n_cols = batched_matrices.shape[2] + + n_rows_primitive = primitive_shape[0] + n_cols_primitive = primitive_shape[1] + + intermediate_shape = (n_slices, n_rows_primitive, n_cols) + intermediate = cp.zeros(intermediate_shape, dtype=batched_matrices.dtype) + + for i in sorted_coeffs_left: + subarray_shape = (n_slices, -1, i['shape'][1], n_cols) + intermediate[:, i['left_indices']] = cp.einsum( + 'naij, api -> napj', + batched_matrices[:, i['right_indices']].reshape(subarray_shape), + i['coeffs'], + ).reshape(n_slices, -1, n_cols) + + intermediate = intermediate.transpose(0, 2, 1) + + result_shape = (n_slices, n_cols_primitive, n_rows_primitive) + result = cp.zeros(result_shape, dtype=batched_matrices.dtype) + + for i in sorted_coeffs_right: + subarray_shape = (n_slices, -1, i['shape'][1], n_rows_primitive) + result[:, i['left_indices']] = cp.einsum( + 'najp, aqj -> naqp', + intermediate[:, i['right_indices']].reshape(subarray_shape), + i['coeffs'], + ).reshape(n_slices, -1, n_rows_primitive) + + return result.transpose(0, 2, 1) + + +def primitive_to_contracted(batched_matrices, sorted_coeffs_left, sorted_coeffs_right, contracted_shape): + assert len(batched_matrices.shape) == 3 + + n_slices = batched_matrices.shape[0] + n_cols = batched_matrices.shape[2] + + n_rows_contracted = contracted_shape[0] + n_cols_contracted = contracted_shape[1] + + intermediate_shape = (n_slices, n_rows_contracted, n_cols) + intermediate = cp.zeros(intermediate_shape, dtype=batched_matrices.dtype) + + for i in sorted_coeffs_left: + subarray_shape = (n_slices, -1, i['shape'][0], n_cols) + intermediate[:, i['right_indices']] = cp.einsum( + 'napq, api -> naiq', + batched_matrices[:, i['left_indices']].reshape(subarray_shape), + i['coeffs'], + ).reshape(n_slices, -1, n_cols) + + intermediate = intermediate.transpose(0, 2, 1) + result_shape = (n_slices, n_cols_contracted, n_rows_contracted) + result = cp.zeros(result_shape, dtype=batched_matrices.dtype) + + for i in sorted_coeffs_right: + subarray_shape = (n_slices, -1, i['shape'][0], n_rows_contracted) + result[:, i['right_indices']] = cp.einsum( + 'naqi, aqj -> naji', + intermediate[:, i['left_indices']].reshape(subarray_shape), + i['coeffs'], + ).reshape(n_slices, -1, n_rows_contracted) + + return result.transpose(0, 2, 1) + + def image_pair_to_difference( vectors_to_neighboring_images, lattice_vectors, @@ -281,16 +537,22 @@ def assign_pairs_to_blocks( pairs_to_blocks_end, n_blocks_abc, n_indices, + i_angular, + j_angular, non_trivial_pairs, i_shells, j_shells, + shell_to_ao_indices, image_indices, vectors_to_neighboring_images, mesh, atm, bas, env, - has_warned_instability + has_warned_instability, + is_non_orthogonal, + threshold, + derivative_order, ): n_blocks = np.prod(n_blocks_abc) n_pairs_on_blocks = cp.zeros(n_blocks + 1, dtype=cp.int32) @@ -352,6 +614,37 @@ def assign_pairs_to_blocks( cast_to_pointer(env) ) + # libgpbc.tailor_gaussian_pairs( + # cast_to_pointer(pairs_on_blocks), + # cast_to_pointer(n_pairs_on_blocks), + # ctypes.c_int(i_angular), + # ctypes.c_int(j_angular), + # cast_to_pointer(non_trivial_pairs), + # cast_to_pointer(i_shells), + # cast_to_pointer(j_shells), + # ctypes.c_int(len(j_shells)), + # cast_to_pointer(shell_to_ao_indices), + # cast_to_pointer(accumulated_n_pairs_per_block), + # cast_to_pointer(sorted_block_index), + # ctypes.c_int(n_contributing_blocks), + # cast_to_pointer(image_indices), + # cast_to_pointer(vectors_to_neighboring_images), + # ctypes.c_int(len(vectors_to_neighboring_images)), + # cast_to_pointer(mesh), + # cast_to_pointer(atm), + # cast_to_pointer(bas), + # cast_to_pointer(env), + # ctypes.c_int(is_non_orthogonal), + # ctypes.c_double(threshold), + # ctypes.c_int(derivative_order), + # ) + + # pairs_on_blocks = pairs_on_blocks[pairs_on_blocks >= 0] + # sorted_block_index = cp.asarray(cp.argsort(-n_pairs_on_blocks), dtype=cp.int32) + # n_contributing_blocks = cp.count_nonzero(n_pairs_on_blocks) + # accumulated_n_pairs_per_block[1:] = cp.cumsum(n_pairs_on_blocks, dtype=cp.int32) + # sorted_block_index = sorted_block_index[:n_contributing_blocks] + return ( pairs_on_blocks, accumulated_n_pairs_per_block, @@ -389,6 +682,12 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"): t0 = log.timer("task generation", *t0) t1 = t0 + derivative_order = 0 + if xc_type == 'GGA': + derivative_order = 1 + if xc_type == 'MGGA': + derivative_order = 2 + pairs = [] for grids_localized, grids_diffused in tasks: subcell_in_localized_region = grids_localized.cell @@ -409,7 +708,7 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"): libgpbc.update_dxyz_dabc(dxyz_dabc.ctypes) n_blocks_abc = np.asarray(np.ceil(mesh / block_size), dtype=cp.int32) equivalent_cell_in_localized, coeff_in_localized = ( - subcell_in_localized_region.decontract_basis(to_cart=True, aggregate=True) + subcell_in_localized_region.decontract_basis(to_cart=True) ) n_primitive_gtos_in_localized = multigrid._pgto_shells( @@ -424,25 +723,20 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"): if grids_diffused is None: grouped_cell = equivalent_cell_in_localized - concatenated_coeff = scipy.linalg.block_diag(coeff_in_localized) + concatenated_coeff = coeff_in_localized else: subcell_in_diffused_region = grids_diffused.cell - equivalent_cell_in_diffused, coeff_in_diffused = ( - subcell_in_diffused_region.decontract_basis( - to_cart=True, aggregate=True - ) - ) + equivalent_cell_in_diffused, coeff_in_diffused = subcell_in_diffused_region.decontract_basis(to_cart=True) grouped_cell = equivalent_cell_in_localized + equivalent_cell_in_diffused - grouped_cell._bas[n_primitive_gtos_in_localized:, 0] -= len( - subcell_in_localized_region._atm - ) + grouped_cell._bas[n_primitive_gtos_in_localized:, 0] -= len(subcell_in_localized_region._atm) - concatenated_coeff = scipy.linalg.block_diag( - coeff_in_localized, coeff_in_diffused - ) - concatenated_coeff = cp.asarray(concatenated_coeff) + concatenated_coeff = coeff_in_localized + coeff_in_diffused + + coeff_in_localized, localized_shape = sort_contraction_coefficients(coeff_in_localized) + + concatenated_coeff, concatenated_shape = sort_contraction_coefficients(concatenated_coeff) n_primitive_gtos_in_two_regions = multigrid._pgto_shells(grouped_cell) rad = vol**(-1./3) * cell.rcut + 1 @@ -460,10 +754,7 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"): else: ao_indices_in_diffused = cp.asarray(grids_diffused.ao_idx, dtype=cp.int32) - concatenated_ao_indices = cp.concatenate( - (ao_indices_in_localized, ao_indices_in_diffused) - ) - coeff_in_localized = cp.asarray(coeff_in_localized) + concatenated_ao_indices = cp.concatenate((ao_indices_in_localized, ao_indices_in_diffused)) per_angular_pairs = [] i_angulars = grouped_cell._bas[:n_primitive_gtos_in_localized, multigrid.ANG_OF] @@ -527,16 +818,22 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"): pairs_to_blocks_end, n_blocks_abc, n_indices, + i_angular, + j_angular, screened_shell_pairs, i_shells, j_shells, + shell_to_ao_indices, image_indices, vectors_to_neighboring_images, mesh, atm, bas, env, - has_warned_instability + has_warned_instability, + is_non_orthogonal, + cell.precision, + derivative_order, ) t1 = log.timer_debug2( "assigning pairs to blocks in angular pair" @@ -559,21 +856,23 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"): pairs.append( { - "per_angular_pairs": per_angular_pairs, - "neighboring_images": vectors_to_neighboring_images, - "grouped_cell": grouped_cell, - "mesh": mesh, # this one is on cpu memory - "fft_grid": fft_grid, - "ao_indices_in_localized": ao_indices_in_localized, - "ao_indices_in_diffused": ao_indices_in_diffused, - "concatenated_ao_indices": concatenated_ao_indices, - "coeff_in_localized": coeff_in_localized, - "concatenated_coeff": concatenated_coeff, - "atm": atm, - "bas": bas, - "env": env, - "dxyz_dabc": dxyz_dabc, - "is_non_orthogonal": is_non_orthogonal, + 'per_angular_pairs': per_angular_pairs, + 'neighboring_images': vectors_to_neighboring_images, + 'grouped_cell': grouped_cell, + 'mesh': mesh, # this one is on cpu memory + 'fft_grid': fft_grid, + 'ao_indices_in_localized': ao_indices_in_localized, + 'ao_indices_in_diffused': ao_indices_in_diffused, + 'concatenated_ao_indices': concatenated_ao_indices, + 'coeff_in_localized': coeff_in_localized, + 'concatenated_coeff': concatenated_coeff, + 'primitive_shape': (localized_shape[0], concatenated_shape[0]), + 'contracted_shape': (localized_shape[1], concatenated_shape[1]), + 'atm': atm, + 'bas': bas, + 'env': env, + 'dxyz_dabc': dxyz_dabc, + 'is_non_orthogonal': is_non_orthogonal, } ) @@ -711,23 +1010,25 @@ def evaluate_density_on_g_mesh(mydf, dm_kpts, kpts=None, xc_type='LDA'): ] n_ao_in_localized = density_matrix_with_rows_in_diffused.shape[3] - density_matrix_with_rows_in_localized[ - :, :, :, n_ao_in_localized: - ] += density_matrix_with_rows_in_diffused.transpose(0, 1, 3, 2).conj() - coeff_sandwiched_density_matrix = cp.einsum( - "nkij,pi->nkpj", - density_matrix_with_rows_in_localized, - pairs["coeff_in_localized"], + density_matrix_with_rows_in_localized[:, :, :, n_ao_in_localized:] += ( + density_matrix_with_rows_in_diffused.transpose(0, 1, 3, 2).conj() ) - coeff_sandwiched_density_matrix = cp.einsum( - "nkpj, qj -> nkpq", - coeff_sandwiched_density_matrix, - pairs["concatenated_coeff"], + n_sets, n_k_points = density_matrix_with_rows_in_localized.shape[:2] + + density_matrix_with_rows_in_localized = density_matrix_with_rows_in_localized.reshape( + -1, *pairs['contracted_shape'] ) - libgpbc.update_dxyz_dabc(pairs["dxyz_dabc"].ctypes) + coeff_sandwiched_density_matrix = contracted_to_primitive( + density_matrix_with_rows_in_localized, + pairs['coeff_in_localized'], + pairs['concatenated_coeff'], + pairs['primitive_shape'], + ).reshape(n_sets, n_k_points, *pairs['primitive_shape']) + + libgpbc.update_dxyz_dabc(pairs['dxyz_dabc'].ctypes) img_phase = image_phase_for_kpts(cell, pairs["neighboring_images"], kpts) density = ( @@ -764,9 +1065,10 @@ def evaluate_density_on_g_mesh(mydf, dm_kpts, kpts=None, xc_type='LDA'): ] += tau density_on_g_mesh = density_on_g_mesh.reshape([n_channels, density_slices, -1]) - if xc_type == 'GGA' or xc_type == 'MGGA': - density_on_g_mesh[:, 1:4] = pbc_tools._get_Gv(mydf.cell, mydf.mesh).T - density_on_g_mesh[:, 1:4] *= density_on_g_mesh[:, :1] * 1j + if xc_type != 'LDA': + for i in range(len(density_on_g_mesh)): + density_on_g_mesh[i, 1:4] = iG_density(density_on_g_mesh[i], cell) + return density_on_g_mesh _eval_rhoG = evaluate_density_on_g_mesh @@ -784,8 +1086,7 @@ def evaluate_xc_wrapper(pairs_info, xc_weights, img_phase, with_tau=False): c_driver = libgpbc.evaluate_xc_with_tau_driver else: c_driver = libgpbc.evaluate_xc_driver - n_i_functions = len(pairs_info["coeff_in_localized"]) - n_j_functions = len(pairs_info["concatenated_coeff"]) + n_i_functions, n_j_functions = pairs_info['primitive_shape'] phase_diff_among_images, image_pair_difference_index = img_phase n_k_points, n_difference_images = phase_diff_among_images.shape @@ -913,8 +1214,16 @@ def convert_xc_on_g_mesh_to_fock( libgpbc.update_dxyz_dabc(pairs["dxyz_dabc"].ctypes) img_phase = image_phase_for_kpts(cell, pairs["neighboring_images"], kpts) fock_slice = evaluate_xc_wrapper(pairs, interpolated_xc, img_phase, with_tau=with_tau) - fock_slice = cp.einsum("nkpq,pi->nkiq", fock_slice, pairs["coeff_in_localized"]) - fock_slice = cp.einsum("nkiq,qj->nkij", fock_slice, pairs["concatenated_coeff"]) + n_sets, n_k_points = fock_slice.shape[:2] + + fock_slice = fock_slice.reshape(-1, *pairs['primitive_shape']) + + fock_slice = primitive_to_contracted( + fock_slice, + pairs['coeff_in_localized'], + pairs['concatenated_coeff'], + pairs['contracted_shape'], + ).reshape(n_sets, n_k_points, *pairs['contracted_shape']) # While mathematically it is correct to have concatenated # ao indices in the addition, but it is possible that the ao @@ -1081,25 +1390,24 @@ def convert_xc_on_g_mesh_to_fock_gradient( ] n_ao_in_localized = density_matrix_slice.shape[2] - density_matrix_slice[ - :, :, :, n_ao_in_localized: - ] += density_matrix_with_rows_in_diffused.transpose(0, 1, 3, 2).conj() + density_matrix_slice[:, :, :, n_ao_in_localized:] += density_matrix_with_rows_in_diffused.transpose( + 0, 1, 3, 2 + ).conj() - coeff_sandwiched_density_matrix = cp.einsum( - "nkij,pi->nkpj", - density_matrix_slice, - pairs["coeff_in_localized"], - ) + n_sets, n_k_points = density_matrix_slice.shape[:2] - coeff_sandwiched_density_matrix = cp.einsum( - "nkpj, qj -> nkpq", - coeff_sandwiched_density_matrix, - pairs["concatenated_coeff"], - ) + density_matrix_slice = density_matrix_slice.reshape(-1, *pairs['contracted_shape']) - libgpbc.update_dxyz_dabc(pairs["dxyz_dabc"].ctypes) + coeff_sandwiched_density_matrix = contracted_to_primitive( + density_matrix_slice, + pairs['coeff_in_localized'], + pairs['concatenated_coeff'], + pairs['primitive_shape'], + ).reshape(n_sets, n_k_points, *pairs['primitive_shape']) - img_phase = image_phase_for_kpts(cell, pairs["neighboring_images"], kpts) + libgpbc.update_dxyz_dabc(pairs['dxyz_dabc'].ctypes) + + img_phase = image_phase_for_kpts(cell, pairs['neighboring_images'], kpts) evaluate_xc_gradient_wrapper( gradient, pairs, @@ -1146,7 +1454,7 @@ def get_pp(ni, kpts=None): mesh = ni.mesh # Compute the vpplocG as # -einsum('ij,ij->j', pseudo.get_vlocG(cell, Gv), cell.get_SI(Gv)) - vpplocG = multigrid_v1.eval_vpplocG(cell, mesh) + vpplocG = ni.vpplocG vpp = convert_xc_on_g_mesh_to_fock(ni, vpplocG, hermi=1, kpts=kpts)[0] t1 = log.timer_debug1("vpploc", *t0) @@ -1196,7 +1504,8 @@ def get_j_kpts(ni, dm_kpts, hermi=1, kpts=None, kpts_band=None): density = evaluate_density_on_g_mesh(ni, dm_kpts, kpts) Gv = pbc_tools._get_Gv(cell, mesh) - coulomb_kernel_on_g_mesh = pbc_tools.get_coulG(cell, Gv=Gv) + # coulomb_kernel_on_g_mesh = pbc_tools.get_coulG(cell, Gv=Gv) + coulomb_kernel_on_g_mesh = ni.coulG coulomb_on_g_mesh = cp.einsum( "ng, g -> ng", density[:, 0], coulomb_kernel_on_g_mesh @@ -1261,13 +1570,11 @@ def nr_rks(ni, cell, grids, xc_code, dm_kpts, relativity=0, hermi=1, density = evaluate_density_on_g_mesh(ni, dm_kpts, kpts, xc_type) rho_sf = density[0, 0] - Gv = pbc_tools._get_Gv(cell, mesh) - coulomb_kernel_on_g_mesh = pbc_tools.get_coulG(cell, Gv=Gv) - coulomb_on_g_mesh = rho_sf * coulomb_kernel_on_g_mesh - coulomb_energy = complex(rho_sf.conj().dot(coulomb_on_g_mesh).get()) - coulomb_energy = (0.5 / cell.vol) * coulomb_energy - log.debug("Multigrid Coulomb energy %s", coulomb_energy) - t0 = log.timer("coulomb", *t0) + coulomb_on_g_mesh = rho_sf * ni.coulG + coulomb_energy = 0.5 * rho_sf.conj().dot(coulomb_on_g_mesh).real + coulomb_energy /= cell.vol + log.debug('Multigrid Coulomb energy %s', coulomb_energy) + t0 = log.timer('coulomb', *t0) weight = cell.vol / ngrids density = ifft_in_place(density.reshape(-1, *mesh)).real.reshape(-1, ngrids) @@ -1299,17 +1606,19 @@ def nr_rks(ni, cell, grids, xc_code, dm_kpts, relativity=0, hermi=1, if xc_type == "LDA" or xc_type == 'HF': pass - elif xc_type == "GGA": - xc_for_fock = ( - xc_for_fock[0] - contract("gp, pg -> p", xc_for_fock[1:4], Gv) * 1j - ) + elif xc_type == 'GGA': + contract_iG_potential(xc_for_fock, cell) + xc_for_fock = xc_for_fock[0] xc_for_fock = xc_for_fock.reshape((-1, ngrids)) - elif xc_type == "MGGA": - xc_for_fock[0] -= contract("gp, pg -> p", xc_for_fock[1:4], Gv) * 1j - xc_for_fock = cp.concatenate([ - xc_for_fock[0].reshape((-1, ngrids)), - xc_for_fock[4].reshape((-1, ngrids)), - ], axis = 0) + elif xc_type == 'MGGA': + contract_iG_potential(xc_for_fock, cell) + xc_for_fock = cp.concatenate( + [ + xc_for_fock[0].reshape((-1, ngrids)), + xc_for_fock[4].reshape((-1, ngrids)), + ], + axis=0, + ) else: raise ValueError(f"Incorrect xc_type = {xc_type}") @@ -1370,11 +1679,8 @@ def nr_uks(ni, cell, grids, xc_code, dm_kpts, relativity=0, hermi=1, density = evaluate_density_on_g_mesh(ni, dm_kpts, kpts, xc_type) rho_sf = density[0, 0] + density[1, 0] - Gv = pbc_tools._get_Gv(cell, mesh) - coulomb_kernel_on_g_mesh = pbc_tools.get_coulG(cell, Gv=Gv) - coulomb_on_g_mesh = rho_sf * coulomb_kernel_on_g_mesh - coulomb_energy = rho_sf.conj().dot(coulomb_on_g_mesh).real - coulomb_energy = 0.5 * float(coulomb_energy.get()) + coulomb_on_g_mesh = rho_sf * ni.coulG + coulomb_energy = 0.5 * rho_sf.conj().dot(coulomb_on_g_mesh).real coulomb_energy /= cell.vol log.debug("Multigrid Coulomb energy %s", coulomb_energy) t0 = log.timer("coulomb", *t0) @@ -1410,17 +1716,20 @@ def nr_uks(ni, cell, grids, xc_code, dm_kpts, relativity=0, hermi=1, if xc_type == "LDA" or xc_type == 'HF': pass - elif xc_type == "GGA": - xc_for_fock = ( - xc_for_fock[:, 0] - contract("ngp, pg -> np", xc_for_fock[:, 1:4], Gv) * 1j + elif xc_type == 'GGA': + for i in xc_for_fock: + contract_iG_potential(i, cell) + xc_for_fock = xc_for_fock[:,0].reshape((nset, -1, ngrids)) + elif xc_type == 'MGGA': + for i in xc_for_fock: + contract_iG_potential(i, cell) + xc_for_fock = cp.concatenate( + [ + xc_for_fock[:, 0].reshape((nset, -1, ngrids)), + xc_for_fock[:, 4].reshape((nset, -1, ngrids)), + ], + axis=1, ) - xc_for_fock = xc_for_fock.reshape((nset, -1, ngrids)) - elif xc_type == "MGGA": - xc_for_fock[:, 0] -= contract("ngp, pg -> np", xc_for_fock[:, 1:4], Gv) * 1j - xc_for_fock = cp.concatenate([ - xc_for_fock[:, 0].reshape((nset, -1, ngrids)), - xc_for_fock[:, 4].reshape((nset, -1, ngrids)), - ], axis = 1) else: raise ValueError(f"Incorrect xc_type = {xc_type}") @@ -1482,10 +1791,7 @@ def get_veff_ip1( density = evaluate_density_on_g_mesh(ni, dm_kpts, kpts, xc_type) Gv = pbc_tools._get_Gv(cell, mesh) - coulomb_kernel_on_g_mesh = pbc_tools.get_coulG(cell, Gv=Gv) - coulomb_on_g_mesh = cp.einsum( - "ng, g -> g", density[:, 0], coulomb_kernel_on_g_mesh - ) + coulomb_on_g_mesh = cp.einsum('ng, g -> g', density[:, 0], ni.coulG) weight = cell.vol / ngrids @@ -1513,17 +1819,21 @@ def get_veff_ip1( if xc_type == "LDA" or xc_type == 'HF': pass - elif xc_type == "GGA": - xc_for_fock = ( - xc_for_fock[:, 0] - contract("ngp, pg -> np", xc_for_fock[:, 1:4], Gv) * 1j - ) + elif xc_type == 'GGA': + for i in xc_for_fock: + contract_iG_potential(i, cell) + xc_for_fock = xc_for_fock[:, 0] xc_for_fock = xc_for_fock.reshape((nset, -1, ngrids)) - elif xc_type == "MGGA": - xc_for_fock[:, 0] -= contract("ngp, pg -> np", xc_for_fock[:, 1:4], Gv) * 1j - xc_for_fock = cp.concatenate([ - xc_for_fock[:, 0].reshape((nset, -1, ngrids)), - xc_for_fock[:, 4].reshape((nset, -1, ngrids)), - ], axis = 1) + elif xc_type == 'MGGA': + for i in xc_for_fock: + contract_iG_potential(i, cell) + xc_for_fock = cp.concatenate( + [ + xc_for_fock[:, 0].reshape((nset, -1, ngrids)), + xc_for_fock[:, 4].reshape((nset, -1, ngrids)), + ], + axis=1, + ) else: raise ValueError(f"Incorrect xc_type = {xc_type}") @@ -1531,10 +1841,7 @@ def get_veff_ip1( xc_for_fock[:, 0] += coulomb_on_g_mesh if with_pseudo_vloc_orbital_derivative: - if cell._pseudo: - xc_for_fock[:, 0] += multigrid_v1.eval_vpplocG(cell, mesh) - else: - xc_for_fock[:, 0] += multigrid_v1.eval_nucG(cell, mesh) + xc_for_fock[:, 0] += ni.vpplocG veff_gradient = convert_xc_on_g_mesh_to_fock_gradient( ni, xc_for_fock, dm_kpts, hermi, kpts, with_tau = (xc_type == "MGGA") @@ -1550,6 +1857,14 @@ def __init__(self, cell): self.mesh = cell.mesh self.tasks = None self.sorted_gaussian_pairs = None + Gv = pbc_tools._get_Gv(cell, cell.mesh) + self.coulG = pbc_tools.get_coulG(cell, Gv=Gv) + if cell._pseudo: + self.vpplocG = multigrid_v1.eval_vpplocG(cell, cell.mesh) + else: + self.vpplocG = multigrid_v1.eval_nucG(cell, cell.mesh) + + self.build() build = sort_gaussian_pairs