diff --git a/functions/__init__.py b/functions/__init__.py
index c030f55..2de81cd 100644
--- a/functions/__init__.py
+++ b/functions/__init__.py
@@ -10,11 +10,10 @@
 
 from .pyrosetta_utils import *
 from .colabdesign_utils import *
-from .biopython_utils import *
+from .biotite_utils import *
 from .generic_utils import *
 
 # suppress warnings
 #os.environ["SLURM_STEP_NODELIST"] = os.environ["SLURM_NODELIST"]
 warnings.simplefilter(action='ignore', category=FutureWarning)
 warnings.simplefilter(action='ignore', category=DeprecationWarning)
-warnings.simplefilter(action='ignore', category=BiopythonWarning)
\ No newline at end of file
diff --git a/functions/biopython_utils.py b/functions/biopython_utils.py
deleted file mode 100644
index a0a4bd4..0000000
--- a/functions/biopython_utils.py
+++ /dev/null
@@ -1,236 +0,0 @@
-####################################
-################ BioPython functions
-####################################
-### Import dependencies
-import os
-import math
-import numpy as np
-from collections import defaultdict
-from scipy.spatial import cKDTree
-from Bio import BiopythonWarning
-from Bio.PDB import PDBParser, DSSP, Selection, Polypeptide, PDBIO, Select, Chain, Superimposer
-from Bio.SeqUtils.ProtParam import ProteinAnalysis
-from Bio.PDB.Selection import unfold_entities
-from Bio.PDB.Polypeptide import is_aa
-
-# analyze sequence composition of design
-def validate_design_sequence(sequence, num_clashes, advanced_settings):
-    note_array = []
-
-    # Check if protein contains clashes after relaxation
-    if num_clashes > 0:
-        note_array.append('Relaxed structure contains clashes.')
-
-    # Check if the sequence contains disallowed amino acids
-    if advanced_settings["omit_AAs"]:
-        restricted_AAs = advanced_settings["omit_AAs"].split(',')
-        for restricted_AA in restricted_AAs:
-            if restricted_AA in sequence:
-                note_array.append('Contains: '+restricted_AA+'!')
-
-    # Analyze the protein
-    analysis = ProteinAnalysis(sequence)
-
-    # Calculate the reduced extinction coefficient per 1% solution
-    extinction_coefficient_reduced = analysis.molar_extinction_coefficient()[0]
-    molecular_weight = round(analysis.molecular_weight() / 1000, 2)
-    extinction_coefficient_reduced_1 = round(extinction_coefficient_reduced / molecular_weight * 0.01, 2)
-
-    # Check if the absorption is high enough
-    if extinction_coefficient_reduced_1 <= 2:
-        note_array.append(f'Absorption value is {extinction_coefficient_reduced_1}, consider adding tryptophane to design.')
-
-    # Join the notes into a single string
-    notes = ' '.join(note_array)
-
-    return notes
-
-# temporary function, calculate RMSD of input PDB and trajectory target
-def target_pdb_rmsd(trajectory_pdb, starting_pdb, chain_ids_string):
-    # Parse the PDB files
-    parser = PDBParser(QUIET=True)
-    structure_trajectory = parser.get_structure('trajectory', trajectory_pdb)
-    structure_starting = parser.get_structure('starting', starting_pdb)
-
-    # Extract chain A from trajectory_pdb
-    chain_trajectory = structure_trajectory[0]['A']
-
-    # Extract the specified chains from starting_pdb
-    chain_ids = chain_ids_string.split(',')
-    residues_starting = []
-    for chain_id in chain_ids:
-        chain_id = chain_id.strip()
-        chain = structure_starting[0][chain_id]
-        for residue in chain:
-            if is_aa(residue, standard=True):
-                residues_starting.append(residue)
-
-    # Extract residues from chain A in trajectory_pdb
-    residues_trajectory = [residue for residue in chain_trajectory if is_aa(residue, standard=True)]
-
-    # Ensure that both structures have the same number of residues
-    min_length = min(len(residues_starting), len(residues_trajectory))
-    residues_starting = residues_starting[:min_length]
-    residues_trajectory = residues_trajectory[:min_length]
-
-    # Collect CA atoms from the two sets of residues
-    atoms_starting = [residue['CA'] for residue in residues_starting if 'CA' in residue]
-    atoms_trajectory = [residue['CA'] for residue in residues_trajectory if 'CA' in residue]
-
-    # Calculate RMSD using structural alignment
-    sup = Superimposer()
-    sup.set_atoms(atoms_starting, atoms_trajectory)
-    rmsd = sup.rms
-
-    return round(rmsd, 2)
-
-# detect C alpha clashes for deformed trajectories
-def calculate_clash_score(pdb_file, threshold=2.4, only_ca=False):
-    parser = PDBParser(QUIET=True)
-    structure = parser.get_structure('protein', pdb_file)
-
-    atoms = []
-    atom_info = []  # Detailed atom info for debugging and processing
-
-    for model in structure:
-        for chain in model:
-            for residue in chain:
-                for atom in residue:
-                    if atom.element == 'H':  # Skip hydrogen atoms
-                        continue
-                    if only_ca and atom.get_name() != 'CA':
-                        continue
-                    atoms.append(atom.coord)
-                    atom_info.append((chain.id, residue.id[1], atom.get_name(), atom.coord))
-
-    tree = cKDTree(atoms)
-    pairs = tree.query_pairs(threshold)
-
-    valid_pairs = set()
-    for (i, j) in pairs:
-        chain_i, res_i, name_i, coord_i = atom_info[i]
-        chain_j, res_j, name_j, coord_j = atom_info[j]
-
-        # Exclude clashes within the same residue
-        if chain_i == chain_j and res_i == res_j:
-            continue
-
-        # Exclude directly sequential residues in the same chain for all atoms
-        if chain_i == chain_j and abs(res_i - res_j) == 1:
-            continue
-
-        # If calculating sidechain clashes, only consider clashes between different chains
-        if not only_ca and chain_i == chain_j:
-            continue
-
-        valid_pairs.add((i, j))
-
-    return len(valid_pairs)
-
-three_to_one_map = {
-    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
-    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
-    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
-    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
-}
-
-# identify interacting residues at the binder interface
-def hotspot_residues(trajectory_pdb, binder_chain="B", atom_distance_cutoff=4.0):
-    # Parse the PDB file
-    parser = PDBParser(QUIET=True)
-    structure = parser.get_structure("complex", trajectory_pdb)
-
-    # Get the specified chain
-    binder_atoms = Selection.unfold_entities(structure[0][binder_chain], 'A')
-    binder_coords = np.array([atom.coord for atom in binder_atoms])
-
-    # Get atoms and coords for the target chain
-    target_atoms = Selection.unfold_entities(structure[0]['A'], 'A')
-    target_coords = np.array([atom.coord for atom in target_atoms])
-
-    # Build KD trees for both chains
-    binder_tree = cKDTree(binder_coords)
-    target_tree = cKDTree(target_coords)
-
-    # Prepare to collect interacting residues
-    interacting_residues = {}
-
-    # Query the tree for pairs of atoms within the distance cutoff
-    pairs = binder_tree.query_ball_tree(target_tree, atom_distance_cutoff)
-
-    # Process each binder atom's interactions
-    for binder_idx, close_indices in enumerate(pairs):
-        binder_residue = binder_atoms[binder_idx].get_parent()
-        binder_resname = binder_residue.get_resname()
-
-        # Convert three-letter code to single-letter code using the manual dictionary
-        if binder_resname in three_to_one_map:
-            aa_single_letter = three_to_one_map[binder_resname]
-            for close_idx in close_indices:
-                target_residue = target_atoms[close_idx].get_parent()
-                interacting_residues[binder_residue.id[1]] = aa_single_letter
-
-    return interacting_residues
-
-# calculate secondary structure percentage of design
-def calc_ss_percentage(pdb_file, advanced_settings, chain_id="B", atom_distance_cutoff=4.0):
-    # Parse the structure
-    parser = PDBParser(QUIET=True)
-    structure = parser.get_structure('protein', pdb_file)
-    model = structure[0]  # Consider only the first model in the structure
-
-    # Calculate DSSP for the model
-    dssp = DSSP(model, pdb_file, dssp=advanced_settings["dssp_path"])
-
-    # Prepare to count residues
-    ss_counts = defaultdict(int)
-    ss_interface_counts = defaultdict(int)
-    plddts_interface = []
-    plddts_ss = []
-
-    # Get chain and interacting residues once
-    chain = model[chain_id]
-    interacting_residues = set(hotspot_residues(pdb_file, chain_id, atom_distance_cutoff).keys())
-
-    for residue in chain:
-        residue_id = residue.id[1]
-        if (chain_id, residue_id) in dssp:
-            ss = dssp[(chain_id, residue_id)][2]  # Get the secondary structure
-            ss_type = 'loop'
-            if ss in ['H', 'G', 'I']:
-                ss_type = 'helix'
-            elif ss == 'E':
-                ss_type = 'sheet'
-
-            ss_counts[ss_type] += 1
-
-            if ss_type != 'loop':
-                # calculate secondary structure normalised pLDDT
-                avg_plddt_ss = sum(atom.bfactor for atom in residue) / len(residue)
-                plddts_ss.append(avg_plddt_ss)
-
-            if residue_id in interacting_residues:
-                ss_interface_counts[ss_type] += 1
-
-                # calculate interface pLDDT
-                avg_plddt_residue = sum(atom.bfactor for atom in residue) / len(residue)
-                plddts_interface.append(avg_plddt_residue)
-
-    # Calculate percentages
-    total_residues = sum(ss_counts.values())
-    total_interface_residues = sum(ss_interface_counts.values())
-
-    percentages = calculate_percentages(total_residues, ss_counts['helix'], ss_counts['sheet'])
-    interface_percentages = calculate_percentages(total_interface_residues, ss_interface_counts['helix'], ss_interface_counts['sheet'])
-
-    i_plddt = round(sum(plddts_interface) / len(plddts_interface) / 100, 2) if plddts_interface else 0
-    ss_plddt = round(sum(plddts_ss) / len(plddts_ss) / 100, 2) if plddts_ss else 0
-
-    return (*percentages, *interface_percentages, i_plddt, ss_plddt)
-
-def calculate_percentages(total, helix, sheet):
-    helix_percentage = round((helix / total) * 100,2) if total > 0 else 0
-    sheet_percentage = round((sheet / total) * 100,2) if total > 0 else 0
-    loop_percentage = round(((total - helix - sheet) / total) * 100,2) if total > 0 else 0
-
-    return helix_percentage, sheet_percentage, loop_percentage
\ No newline at end of file
diff --git a/functions/biotite_utils.py b/functions/biotite_utils.py
new file mode 100644
index 0000000..5b9efbc
--- /dev/null
+++ b/functions/biotite_utils.py
@@ -0,0 +1,332 @@
+####################################
+################ Biotite functions
+####################################
+### Import dependencies
+from biotite.application.application import AppState, requires_state
+from biotite.application.localapp import get_version, cleanup_tempfile, LocalApp
+from scipy.spatial import KDTree
+from subprocess import SubprocessError
+from tempfile import NamedTemporaryFile
+import biotite.application.dssp as b_dssp
+import biotite.sequence as b_sequence
+import biotite.structure as b_structure
+import fastpdb
+import numpy as np
+
+# analyze sequence composition of design
+def validate_design_sequence(sequence, num_clashes, advanced_settings):
+    """
+    Collect warnings about a designed sequence as a single string.
+
+    sequence is the one-letter amino acid sequence, num_clashes the number of
+    clashes left after relaxation and advanced_settings a dict whose
+    "omit_AAs" entry is a comma-separated list of disallowed residues (or a
+    falsy value). Returns the notes joined by spaces, "" if there are none.
+    """
+    note_array = []
+    bseq = b_sequence.ProteinSequence(sequence)
+
+    # Check if protein contains clashes after relaxation
+    if num_clashes > 0:
+        note_array.append('Relaxed structure contains clashes.')
+
+    # Check if the sequence contains disallowed amino acids
+    if advanced_settings["omit_AAs"]:
+        for restricted_AA in advanced_settings["omit_AAs"].split(','):
+            # Tolerate "C, M" and trailing commas: an empty entry would
+            # otherwise match every sequence
+            restricted_AA = restricted_AA.strip()
+            if restricted_AA and restricted_AA in sequence:
+                note_array.append('Contains: '+restricted_AA+'!')
+
+    # Calculate the reduced extinction coefficient per 1% solution
+    # (only Trp and Tyr are counted, i.e. all cysteines are assumed reduced)
+    num_aa = bseq.get_symbol_frequency()
+    extinction_coefficient_reduced = num_aa["W"] * 5500 + num_aa["Y"] * 1490
+    molecular_weight = np.round(bseq.get_molecular_weight() / 1000, 2)
+    extinction_coefficient_reduced_1 = np.round(extinction_coefficient_reduced / molecular_weight * 0.01, 2)
+
+    # Check if the absorption is high enough
+    if extinction_coefficient_reduced_1 <= 2:
+        note_array.append(f'Absorption value is {extinction_coefficient_reduced_1}, consider adding tryptophane to design.')
+
+    # Join the notes into a single string
+    notes = ' '.join(note_array)
+
+    return notes
+
+# temporary function, calculate RMSD of input PDB and trajectory target
+def target_pdb_rmsd(trajectory_pdb, starting_pdb, chain_ids_string):
+    """
+    Return the CA RMSD between chain A of trajectory_pdb and the chains of
+    starting_pdb listed in chain_ids_string (comma-separated), after
+    superimposing the trajectory onto the starting structure.
+
+    Both CA selections are truncated to the shorter one, so atoms are paired
+    by position rather than by residue number. Rounded to 2 decimals.
+    """
+    # Parse the PDB files
+    file_trajectory = fastpdb.PDBFile.read(trajectory_pdb)
+    file_starting = fastpdb.PDBFile.read(starting_pdb)
+
+    aa_trajectory = file_trajectory.get_structure(model=1)
+    aa_starting = file_starting.get_structure(model=1)
+
+    # Extract CA atoms from starting_pdb
+    chain_ids = [chain_id.strip() for chain_id in chain_ids_string.split(',')]
+
+    aa_starting = aa_starting[(np.isin(aa_starting.chain_id, chain_ids))
+                              & (aa_starting.atom_name == "CA")
+                              & (aa_starting.element == "C")
+                              & (b_structure.filter_canonical_amino_acids(aa_starting))]
+
+    # Extract CA atoms from chain A in trajectory_pdb
+    aa_trajectory = aa_trajectory[(aa_trajectory.chain_id == "A")
+                                  & (aa_trajectory.atom_name == "CA")
+                                  & (aa_trajectory.element == "C")
+                                  & (b_structure.filter_canonical_amino_acids(aa_trajectory))]
+
+    # Ensure that both structures have the same number of residues
+    min_length = min(aa_starting.array_length(), aa_trajectory.array_length())
+    aa_starting = aa_starting[:min_length]
+    aa_trajectory = aa_trajectory[:min_length]
+
+    # Calculate RMSD using structural alignment
+    superimposed, _ = b_structure.superimpose(aa_starting, aa_trajectory)
+    rmsd = b_structure.rmsd(aa_starting, superimposed)
+
+    return np.round(rmsd, 2)
+
+# detect C alpha clashes for deformed trajectories
+def calculate_clash_score(pdb_file, threshold=2.4, only_ca=False):
+    """
+    Count pairs of non-hydrogen atoms closer than threshold in pdb_file.
+
+    With only_ca=True only CA atoms are compared, ignoring pairs within one
+    residue or between directly sequential residues of the same chain.
+    Otherwise all heavy atoms are compared and only pairs between different
+    chains are counted. Returns the number of clashing pairs as an int.
+    """
+    file_structure = fastpdb.PDBFile.read(pdb_file)
+    aa_structure = file_structure.get_structure(model=1)
+
+    atoms = aa_structure[aa_structure.element != "H"]
+    if only_ca:
+        atoms = atoms[(atoms.atom_name == "CA") & (atoms.element == "C")]
+
+    # A clash needs two atoms; this also avoids building an empty KD tree
+    if atoms.array_length() < 2:
+        return 0
+
+    tree = KDTree(atoms.coord)
+    pairs_set = tree.query_pairs(threshold)
+    if not pairs_set:
+        return 0
+
+    pairs = np.array(list(pairs_set))
+
+    atoms_i = atoms[pairs[:, 0]]
+    atoms_j = atoms[pairs[:, 1]]
+
+    different_chain = atoms_i.chain_id != atoms_j.chain_id
+    if not only_ca:
+        # If calculating sidechain clashes, only consider clashes between different chains
+        valid_mask = different_chain
+    else:
+        # Exclude clashes within the same residue or between sequential
+        # residues of the same chain: a pair is a valid clash if it spans
+        # different chains or its residues are more than 1 position apart
+        valid_mask = different_chain | (np.abs(atoms_i.res_id - atoms_j.res_id) > 1)
+
+    # Plain int, consistent with the early returns above
+    return int(valid_mask.sum())
+
+three_to_one_map = {
+    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
+    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
+    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
+    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
+}
+
+# identify interacting residues at the binder interface
+def hotspot_residues(trajectory_pdb, binder_chain="B", atom_distance_cutoff=4.0):
+    """
+    Map residue IDs of binder_chain residues that have an atom within
+    atom_distance_cutoff of any chain "A" atom to their one-letter code.
+
+    Residues missing from three_to_one_map are skipped. Returns {} if there
+    are no contacts or either chain has no atoms.
+    """
+    # Parse the PDB file
+    file_structure = fastpdb.PDBFile.read(trajectory_pdb)
+    aa_structure = file_structure.get_structure(model=1)
+
+    # Get the specified chain
+    aa_binder = aa_structure[aa_structure.chain_id == binder_chain]
+
+    # Get atoms and coords for the target chain
+    aa_target = aa_structure[aa_structure.chain_id == "A"]
+
+    # Nothing can be in contact if either chain is absent
+    if aa_binder.array_length() == 0 or aa_target.array_length() == 0:
+        return {}
+
+    # Build KD trees for both chains
+    binder_tree = KDTree(aa_binder.coord)
+    target_tree = KDTree(aa_target.coord)
+
+    # Query the tree for pairs of atoms within the distance cutoff
+    pairs = binder_tree.query_ball_tree(target_tree, atom_distance_cutoff)
+
+    is_close = np.asarray([len(idx) > 0 for idx in pairs], dtype=bool)
+
+    # Filter residues with the manual dictionary. Both masks are built on the
+    # unfiltered binder atoms, so they have the same length when combined
+    is_mapped = np.isin(aa_binder.res_name, list(three_to_one_map.keys()))
+    aa_interface = aa_binder[is_close & is_mapped]
+
+    interacting_residues = {
+        int(res_id): three_to_one_map[res_name]
+        for res_id, res_name in zip(aa_interface.res_id, aa_interface.res_name)
+    }
+
+    return interacting_residues
+
+class DsspFromFile(b_dssp.DsspApp):
+    """
+    Implementation of biotite.application.dssp.DsspApp that processes input
+    files directly without the need to temporarily save the structure and
+    returns the secondary structure together with chain IDs and residue IDs.
+    """
+    def __init__(self, in_file, bin_path="mkdssp"):
+        # This class is given a file path instead of a structure, so LocalApp
+        # is initialised directly rather than through DsspApp
+        LocalApp.__init__(self, bin_path)
+        try:
+            # The parameters have changed in version 4
+            self._new_cli = get_version(bin_path)[0] >= 4
+        except SubprocessError:
+            # In older versions, no version is returned with `--version`
+            # -> a SubprocessError is raised
+            self._new_cli = False
+
+        self._in_file = in_file
+        # delete=False: mkdssp writes to the file by name; see clean_up()
+        self._out_file = NamedTemporaryFile("r", suffix=".dssp", delete=False)
+
+    def run(self):
+        if self._new_cli:
+            self.set_arguments([self._in_file, self._out_file.name])
+        else:
+            self.set_arguments(["-i", self._in_file, "-o", self._out_file.name])
+        LocalApp.run(self)
+
+    def evaluate(self):
+        LocalApp.evaluate(self)
+        lines = self._out_file.read().split("\n")
+        # Index where SSE records start
+        sse_start = None
+        for i, line in enumerate(lines):
+            if line.startswith("  #  RESIDUE AA STRUCTURE"):
+                sse_start = i + 1
+        if sse_start is None:
+            raise ValueError("DSSP file does not contain SSE records")
+        # Remove "!" for missing residues
+        lines = [
+            line for line in lines[sse_start:] if len(line) != 0 and line[13] != "!"
+        ]
+        self._sse = np.zeros(len(lines), dtype="U1")
+        self._sse_chainids = np.zeros(len(lines), dtype="U1")
+        self._sse_resids = np.zeros(len(lines), dtype=int)
+        # Parse file for SSE letters
+        for i, line in enumerate(lines):
+            self._sse[i] = line[16]
+            self._sse_chainids[i] = line[11]
+            self._sse_resids[i] = int(line[5:10])
+        self._sse[self._sse == " "] = "C"
+
+    def clean_up(self):
+        LocalApp.clean_up(self)
+        cleanup_tempfile(self._out_file)
+
+    @staticmethod
+    def annotate_sse(in_file, bin_path="mkdssp"):
+        """Run DSSP on in_file and return (sse, chain_ids, res_ids) arrays, one entry per residue."""
+        app = DsspFromFile(in_file, bin_path)
+        app.start()
+        app.join()
+        return app.get_sse()
+
+    @requires_state(AppState.JOINED)
+    def get_sse(self):
+        return self._sse, self._sse_chainids, self._sse_resids
+
+def _count_secondary_structure(sse):
+    """Count helix (H, G, I), sheet (E) and loop (anything else) codes in an array of DSSP letters."""
+    helix = int(np.isin(sse, ['H', 'G', 'I']).sum())
+    sheet = int((sse == 'E').sum())
+    return {"helix": helix, "sheet": sheet, "loop": sse.size - helix - sheet}
+
+def _mean_residue_plddt(atoms):
+    """Mean over residues of the per-residue mean B-factor, divided by 100; 0 if atoms is empty."""
+    # np.mean of an empty selection would yield nan plus a RuntimeWarning
+    if atoms.array_length() == 0:
+        return 0
+    per_residue = b_structure.apply_residue_wise(atoms, atoms.b_factor, np.mean)
+    return np.round(np.mean(per_residue) / 100, 2)
+
+# calculate secondary structure percentage of design
+def calc_ss_percentage(pdb_file, advanced_settings, chain_id="B", atom_distance_cutoff=4.0):
+    """
+    Summarise DSSP secondary structure and B-factors for one chain.
+
+    Returns (helix %, sheet %, loop %, interface helix %, interface sheet %,
+    interface loop %, interface pLDDT, secondary-structure pLDDT), where the
+    interface is the set of chain_id residues reported by hotspot_residues
+    and both pLDDT values are B-factor averages divided by 100.
+    """
+    # Parse the structure
+    file_structure = fastpdb.PDBFile.read(pdb_file)
+    aa_structure = file_structure.get_structure(model=1, extra_fields=["b_factor"])  # Consider only the first model in the structure
+    aa_chain = aa_structure[aa_structure.chain_id == chain_id]
+
+    # Calculate DSSP for the model
+    dssp_ss, dssp_chainids, dssp_resids = DsspFromFile.annotate_sse(pdb_file, bin_path=advanced_settings["dssp_path"])
+
+    in_chain = dssp_chainids == chain_id
+    chain_ss = dssp_ss[in_chain]
+    chain_ss_resids = dssp_resids[in_chain]
+
+    # Interface residues (only those that DSSP returned)
+    interacting_residues = list(hotspot_residues(pdb_file, chain_id, atom_distance_cutoff).keys())
+    at_interface = np.isin(chain_ss_resids, interacting_residues)
+    interface_ss = chain_ss[at_interface]
+    interface_ss_resids = chain_ss_resids[at_interface]
+
+    # Count secondary structures; the interface counts are derived from the
+    # interface residues only
+    ss_counts = _count_secondary_structure(chain_ss)
+    ss_interface_counts = _count_secondary_structure(interface_ss)
+
+    # Calculate nonloop pLDDTs
+    chain_nonloop_resids = chain_ss_resids[np.isin(chain_ss, ['H', 'G', 'I', 'E'])]
+    aa_chain_nonloop = aa_chain[np.isin(aa_chain.res_id, chain_nonloop_resids)]
+    ss_plddt = _mean_residue_plddt(aa_chain_nonloop)
+
+    # Calculate interface pLDDT (only use residues returned from DSSP)
+    aa_interface_chain = aa_chain[np.isin(aa_chain.res_id, interface_ss_resids)]
+    i_plddt = _mean_residue_plddt(aa_interface_chain)
+
+    # Calculate percentages
+    percentages = calculate_percentages(chain_ss.size, ss_counts['helix'], ss_counts['sheet'])
+    interface_percentages = calculate_percentages(interface_ss.size, ss_interface_counts['helix'], ss_interface_counts['sheet'])
+
+    return (*percentages, *interface_percentages, i_plddt, ss_plddt)
+
+def calculate_percentages(total, helix, sheet):
+    """Return the (helix, sheet, loop) percentages of total, rounded to 2 decimals; all 0 if total is not positive."""
+    helix_percentage = round((helix / total) * 100,2) if total > 0 else 0
+    sheet_percentage = round((sheet / total) * 100,2) if total > 0 else 0
+    loop_percentage = round(((total - helix - sheet) / total) * 100,2) if total > 0 else 0
+
+    return helix_percentage, sheet_percentage, loop_percentage
diff --git a/functions/colabdesign_utils.py b/functions/colabdesign_utils.py
index 4bb1db5..307f108 100644
--- a/functions/colabdesign_utils.py
+++ b/functions/colabdesign_utils.py
@@ -13,7 +13,7 @@
 from colabdesign.af.alphafold.common import residue_constants
 from colabdesign.af.loss import get_ptm, mask_loss, get_dgram_bins, _get_con_loss
 from colabdesign.shared.utils import copy_dict
-from .biopython_utils import hotspot_residues, calculate_clash_score, calc_ss_percentage, calculate_percentages
+from .biotite_utils import hotspot_residues, calculate_clash_score, calc_ss_percentage, calculate_percentages
 from .pyrosetta_utils import pr_relax, align_pdbs
 from .generic_utils import update_failures
 
diff --git a/functions/pyrosetta_utils.py b/functions/pyrosetta_utils.py
index 4d6c8a2..9df96b2 100644
--- a/functions/pyrosetta_utils.py
+++ b/functions/pyrosetta_utils.py
@@ -14,7 +14,7 @@
 from pyrosetta.rosetta.core.io import pose_from_pose
 from pyrosetta.rosetta.protocols.rosetta_scripts import XmlObjects
 from .generic_utils import clean_pdb
-from .biopython_utils import hotspot_residues
+from .biotite_utils import hotspot_residues
 
 # Rosetta interface scores
 def score_interface(pdb_file, binder_chain="B"):
@@ -240,4 +240,4 @@ def pr_relax(pdb_file, relaxed_pdb_path):
 
     # output relaxed and aligned PDB
     pose.dump_pdb(relaxed_pdb_path)
-    clean_pdb(relaxed_pdb_path)
\ No newline at end of file
+    clean_pdb(relaxed_pdb_path)
diff --git a/install_bindcraft.sh b/install_bindcraft.sh
index 2fffc69..0f15425 100644
--- a/install_bindcraft.sh
+++ b/install_bindcraft.sh
@@ -64,14 +64,14 @@ echo -e "BindCraft environment activated at ${CONDA_BASE}/envs/BindCraft"
 echo -e "Instaling conda requirements\n"
 if [ -n "$cuda" ]; then
     CONDA_OVERRIDE_CUDA="$cuda" $pkg_manager install \
-        pip pandas matplotlib 'numpy<2.0.0' biopython scipy pdbfixer seaborn libgfortran5 tqdm jupyter ffmpeg pyrosetta fsspec py3dmol \
+        pip pandas matplotlib 'numpy<2.0.0' biotite fastpdb scipy pdbfixer seaborn libgfortran5 tqdm jupyter ffmpeg pyrosetta fsspec py3dmol \
         chex dm-haiku 'flax<0.10.0' dm-tree joblib ml-collections immutabledict optax \
         'jax>=0.4,<=0.6.0' 'jaxlib>=0.4,<=0.6.0=*cuda*' cuda-nvcc cudnn \
         -c conda-forge -c nvidia --channel https://conda.graylab.jhu.edu -y \
         || { echo -e "Error: Failed to install conda packages."; exit 1; }
 else
     $pkg_manager install \
-        pip pandas matplotlib 'numpy<2.0.0' biopython scipy pdbfixer seaborn libgfortran5 tqdm jupyter ffmpeg pyrosetta fsspec py3dmol \
+        pip pandas matplotlib 'numpy<2.0.0' biotite fastpdb scipy pdbfixer seaborn libgfortran5 tqdm jupyter ffmpeg pyrosetta fsspec py3dmol \
         chex dm-haiku 'flax<0.10.0' dm-tree joblib ml-collections immutabledict optax \
         'jax>=0.4,<=0.6.0' 'jaxlib>=0.4,<=0.6.0' \
         -c conda-forge -c nvidia --channel https://conda.graylab.jhu.edu -y \
@@ -79,7 +79,7 @@ else
 fi
 
 # make sure all required packages were installed
-required_packages=(pip pandas libgfortran5 matplotlib numpy biopython scipy pdbfixer seaborn tqdm jupyter ffmpeg pyrosetta fsspec py3dmol chex dm-haiku dm-tree joblib ml-collections immutabledict optax jaxlib jax cuda-nvcc cudnn)
+required_packages=(pip pandas libgfortran5 matplotlib numpy biotite fastpdb scipy pdbfixer seaborn tqdm jupyter ffmpeg pyrosetta fsspec py3dmol chex dm-haiku dm-tree joblib ml-collections immutabledict optax jaxlib jax cuda-nvcc cudnn)
 missing_packages=()
 
 # Check each package
@@ -138,4 +138,4 @@ t=$SECONDS
 echo -e "Successfully finished BindCraft installation!\n"
 echo -e "Activate environment using command: \"$pkg_manager activate BindCraft\""
 echo -e "\n"
-echo -e "Installation took $(($t / 3600)) hours, $((($t / 60) % 60)) minutes and $(($t % 60)) seconds."
\ No newline at end of file
+echo -e "Installation took $(($t / 3600)) hours, $((($t / 60) % 60)) minutes and $(($t % 60)) seconds."