From 68388a84ed5f0d6d994fdff624a0fa2309637a39 Mon Sep 17 00:00:00 2001 From: Qiming Sun Date: Thu, 28 May 2026 14:02:01 -0700 Subject: [PATCH] Handle wb97x-d4 --- gpu4pyscf/scf/dispersion.py | 54 +++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/gpu4pyscf/scf/dispersion.py b/gpu4pyscf/scf/dispersion.py index 0a0019a3e..df09884a0 100644 --- a/gpu4pyscf/scf/dispersion.py +++ b/gpu4pyscf/scf/dispersion.py @@ -1,4 +1,4 @@ -# Copyright 2021-2024 The PySCF Developers. All Rights Reserved. +# Copyright 2021-2026 The PySCF Developers. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,12 +20,17 @@ dispersion correction for HF and DFT ''' +import warnings from functools import lru_cache from pyscf.lib import logger -from pyscf import scf +from gpu4pyscf import scf +from pyscf import __config__ + +DFTD4_RECOMMENDATIONS = getattr(__config__, 'DFTD4_RECOMMENDATIONS', True) # supported dispersion corrections DISP_VERSIONS = ['d3bj', 'd3zero', 'd3bjm', 'd3zerom', 'd3op', 'd4'] +# XC names for dftd3 and dftd4 inputs XC_MAP = {'wb97m-d3bj': 'wb97m', 'b97m-d3bj': 'b97m', 'wb97x-d3bj': 'wb97x', @@ -47,6 +52,10 @@ 'wb97x-d3bj': ('wb97x-v', False, 'd3bj'), 'wb97x-3c': ('wb97x-v', False, 'd4:wb97x-3c'), } +if DFTD4_RECOMMENDATIONS: + _white_list['b97m-d4'] = ('b97m_v', False, 'd4:b97m') + _white_list['wb97m-d4'] = ('wb97m_v', False, 'd4:wb97m') + _white_list['wb97x-d4'] = ('wb97x_v', False, 'd4:wb97x') # These xc functionals are not supported yet _black_list = { @@ -72,14 +81,25 @@ def parse_dft(xc_code): if method_lower in _black_list: raise NotImplementedError(f'{method_lower} is not supported yet.') - if method_lower in _white_list: - return _white_list[method_lower] - if method_lower.endswith('-3c'): if method_lower == "wb97x-3c": return _white_list[method_lower] raise NotImplementedError('Only wb97x-3c is supported for now. Other 3c methods are not supported yet.') + if method_lower in _white_list: + if '-d4' in method_lower: + xc, nlc, disp = _white_list[method_lower] + if disp == 'd4:wb97x': + disp = disp + '-2008' + warnings.warn(f''' +{method_lower} now follows the DFT-D4 recommendations: {xc} with D4 instead of vv10. +To reproduce the previous PySCF behavior (v2.13 and earlier versions), explicitly specify + + mf.xc = '{xc[:-2]}' + mf.disp = '{disp}' +''', FutureWarning, stacklevel=2) + return _white_list[method_lower] + if '-d3' in method_lower or '-d4' in method_lower: xc, disp = method_lower.split('-') else: @@ -120,11 +140,11 @@ def parse_disp(dft_method=None, disp=None): >>> parse_disp(None, 'd4:wb97x-3c') ('wb97x-3c', 'd4', True) ''' - + # If anything not specified, return None if dft_method is None and disp is None: return None, None, False - + def process_3body(disp_version): if not disp_version: return disp_version, False @@ -141,11 +161,11 @@ def process_3body(disp_version): if dft_method is not None: dft_lower = dft_method.lower() - xc, _, disp_from_dft = parse_dft(dft_lower) + xc, nlc, disp_from_dft = parse_dft(dft_lower) if xc in XC_MAP: xc = XC_MAP[xc] - # Use disp if specfied + # Use disp if specified # returned method will be the latter part of disp if disp is a string with colon, otherwise, use xc if disp is not None: if ":" in disp: @@ -157,16 +177,16 @@ def process_3body(disp_version): return xc, disp, with_3body else: raise ValueError(f"the method used in dispersion {disp} is not specified.") - + # otherwise, use disp_from_dft if disp_from_dft is None: return None, None, False - + if ":" in disp_from_dft: disp_version, method = disp_from_dft.split(':') disp_version, with_3body = process_3body(disp_version) return method, disp_version, with_3body - + disp_from_dft, with_3body = process_3body(disp_from_dft) return xc, disp_from_dft, with_3body @@ -180,20 +200,20 @@ def check_disp(mf, disp=None): Args: mf (scf.hf.SCF): The SCF object (HF or DFT). - disp (str or bool, optional): Dispersion version to check. + disp (str or bool, optional): Dispersion version to check. If None, uses `mf.disp`. If False, returns False immediately. Returns: - bool: True if dispersion is enabled and supported. + bool: True if dispersion is enabled and supported. False if dispersion is disabled (disp=False) or not specified/implied. - + Raises: ValueError: If the dispersion version is not supported. ''' if disp is None: disp = getattr(mf, 'disp', None) - if disp is False: + if disp is False or disp == 0: return False # To prevent mf.do_disp() triggering the SCF.__getattr__ method, do not use @@ -233,7 +253,7 @@ def get_dispersion(mf, disp=None, with_3body=None, verbose=None): The dispersion correction energy. Note: - Priority of `disp` and `with_atm`: + Priority of `disp` and `with_3body`: 1. Function arguments (disp, with_3body) 2. mf.disp (if available) 3. mf.xc (parsed from the functional name)