Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions gpu4pyscf/scf/dispersion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021-2024 The PySCF Developers. All Rights Reserved.
# Copyright 2021-2026 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,12 +20,17 @@
dispersion correction for HF and DFT
'''

import warnings
from functools import lru_cache
from pyscf.lib import logger
from pyscf import scf
from gpu4pyscf import scf
from pyscf import __config__

DFTD4_RECOMMENDATIONS = getattr(__config__, 'DFTD4_RECOMMENDATIONS', True)

# supported dispersion corrections
DISP_VERSIONS = ['d3bj', 'd3zero', 'd3bjm', 'd3zerom', 'd3op', 'd4']
# XC names for dftd3 and dftd4 inputs
XC_MAP = {'wb97m-d3bj': 'wb97m',
'b97m-d3bj': 'b97m',
'wb97x-d3bj': 'wb97x',
Expand All @@ -47,6 +52,10 @@
'wb97x-d3bj': ('wb97x-v', False, 'd3bj'),
'wb97x-3c': ('wb97x-v', False, 'd4:wb97x-3c'),
}
if DFTD4_RECOMMENDATIONS:
_white_list['b97m-d4'] = ('b97m_v', False, 'd4:b97m')
_white_list['wb97m-d4'] = ('wb97m_v', False, 'd4:wb97m')
_white_list['wb97x-d4'] = ('wb97x_v', False, 'd4:wb97x')

# These xc functionals are not supported yet
_black_list = {
Expand All @@ -72,14 +81,25 @@ def parse_dft(xc_code):
if method_lower in _black_list:
raise NotImplementedError(f'{method_lower} is not supported yet.')

if method_lower in _white_list:
return _white_list[method_lower]

if method_lower.endswith('-3c'):
if method_lower == "wb97x-3c":
return _white_list[method_lower]
raise NotImplementedError('Only wb97x-3c is supported for now. Other 3c methods are not supported yet.')

if method_lower in _white_list:
if '-d4' in method_lower:
xc, nlc, disp = _white_list[method_lower]
if disp == 'd4:wb97x':
disp = disp + '-2008'
warnings.warn(f'''
{method_lower} now follows the DFT-D4 recommendations: {xc} with D4 instead of vv10.
To reproduce the previous PySCF behavior (v2.13 and earlier versions), explicitly specify

mf.xc = '{xc[:-2]}'
mf.disp = '{disp}'
''', FutureWarning, stacklevel=2)
return _white_list[method_lower]

if '-d3' in method_lower or '-d4' in method_lower:
xc, disp = method_lower.split('-')
else:
Expand Down Expand Up @@ -120,11 +140,11 @@ def parse_disp(dft_method=None, disp=None):
>>> parse_disp(None, 'd4:wb97x-3c')
('wb97x-3c', 'd4', True)
'''

# If anything not specified, return None
if dft_method is None and disp is None:
return None, None, False

def process_3body(disp_version):
if not disp_version:
return disp_version, False
Expand All @@ -141,11 +161,11 @@ def process_3body(disp_version):

if dft_method is not None:
dft_lower = dft_method.lower()
xc, _, disp_from_dft = parse_dft(dft_lower)
xc, nlc, disp_from_dft = parse_dft(dft_lower)
if xc in XC_MAP:
xc = XC_MAP[xc]

# Use disp if specfied
# Use disp if specified
# returned method will be the latter part of disp if disp is a string with colon, otherwise, use xc
if disp is not None:
if ":" in disp:
Expand All @@ -157,16 +177,16 @@ def process_3body(disp_version):
return xc, disp, with_3body
else:
raise ValueError(f"the method used in dispersion {disp} is not specified.")

# otherwise, use disp_from_dft
if disp_from_dft is None:
return None, None, False

if ":" in disp_from_dft:
disp_version, method = disp_from_dft.split(':')
disp_version, with_3body = process_3body(disp_version)
return method, disp_version, with_3body

disp_from_dft, with_3body = process_3body(disp_from_dft)
return xc, disp_from_dft, with_3body

Expand All @@ -180,20 +200,20 @@ def check_disp(mf, disp=None):

Args:
mf (scf.hf.SCF): The SCF object (HF or DFT).
disp (str or bool, optional): Dispersion version to check.
disp (str or bool, optional): Dispersion version to check.
If None, uses `mf.disp`.
If False, returns False immediately.

Returns:
bool: True if dispersion is enabled and supported.
bool: True if dispersion is enabled and supported.
False if dispersion is disabled (disp=False) or not specified/implied.

Raises:
ValueError: If the dispersion version is not supported.
'''
if disp is None:
disp = getattr(mf, 'disp', None)
if disp is False:
if disp is False or disp == 0:
return False

# To prevent mf.do_disp() triggering the SCF.__getattr__ method, do not use
Expand Down Expand Up @@ -233,7 +253,7 @@ def get_dispersion(mf, disp=None, with_3body=None, verbose=None):
The dispersion correction energy.

Note:
Priority of `disp` and `with_atm`:
Priority of `disp` and `with_3body`:
1. Function arguments (disp, with_3body)
2. mf.disp (if available)
3. mf.xc (parsed from the functional name)
Expand Down
Loading