Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import re
import warnings
from contextlib import contextmanager
from getpass import getpass
from typing import Callable

import pymysql as client
Expand Down Expand Up @@ -110,9 +109,9 @@ def conn(
host : str, optional
Database hostname.
user : str, optional
MySQL username.
Database username. Required if not set in config.
password : str, optional
MySQL password. Prompts if not provided.
Database password. Required if not set in config.
init_fun : callable, optional
Initialization function called after connection.
reset : bool, optional
Expand All @@ -125,15 +124,24 @@ def conn(
-------
Connection
Persistent database connection.

Raises
------
DataJointError
If user or password is not provided and not set in config.
"""
if not hasattr(conn, "connection") or reset:
host = host if host is not None else config["database.host"]
user = user if user is not None else config["database.user"]
password = password if password is not None else config["database.password"]
if user is None:
user = input("Please enter DataJoint username: ")
raise errors.DataJointError(
"Database user not configured. Set datajoint.config['database.user'] or pass user= argument."
)
if password is None:
password = getpass(prompt="Please enter DataJoint password: ")
raise errors.DataJointError(
"Database password not configured. Set datajoint.config['database.password'] or pass password= argument."
)
init_fun = init_fun if init_fun is not None else config["connection.init_function"]
use_tls = use_tls if use_tls is not None else config["database.use_tls"]
conn.connection = Connection(host, user, password, None, init_fun, use_tls)
Expand Down
10 changes: 6 additions & 4 deletions src/datajoint/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class Schema:
create_schema : bool, optional
If False, raise error if schema doesn't exist. Default True.
create_tables : bool, optional
If False, raise error when accessing missing tables. Default True.
If False, raise error when accessing missing tables.
Default from ``dj.config.database.create_tables`` (True unless configured).
add_objects : dict, optional
Additional objects for the declaration context.

Expand All @@ -93,7 +94,7 @@ def __init__(
*,
connection: Connection | None = None,
create_schema: bool = True,
create_tables: bool = True,
create_tables: bool | None = None,
add_objects: dict[str, Any] | None = None,
) -> None:
"""
Expand All @@ -110,15 +111,16 @@ def __init__(
create_schema : bool, optional
If False, raise error if schema doesn't exist. Default True.
create_tables : bool, optional
If False, raise error when accessing missing tables. Default True.
If False, raise error when accessing missing tables.
Default from ``dj.config.database.create_tables`` (True unless configured).
add_objects : dict, optional
Additional objects for the declaration context.
"""
self.connection = connection
self.database = None
self.context = context
self.create_schema = create_schema
self.create_tables = create_tables
self.create_tables = create_tables if create_tables is not None else config.database.create_tables
self.add_objects = add_objects
self.declare_list = []
if schema_name:
Expand Down
14 changes: 14 additions & 0 deletions src/datajoint/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
"database.user": "DJ_USER",
"database.password": "DJ_PASS",
"database.port": "DJ_PORT",
"database.schema_prefix": "DJ_SCHEMA_PREFIX",
"database.create_tables": "DJ_CREATE_TABLES",
"loglevel": "DJ_LOG_LEVEL",
}

Expand Down Expand Up @@ -185,6 +187,18 @@ class DatabaseSettings(BaseSettings):
port: int = Field(default=3306, validation_alias="DJ_PORT")
reconnect: bool = True
use_tls: bool | None = None
schema_prefix: str = Field(
default="",
validation_alias="DJ_SCHEMA_PREFIX",
description="Project-specific prefix for schema names. "
"Not automatically applied; use dj.config.database.schema_prefix when creating schemas.",
)
create_tables: bool = Field(
default=True,
validation_alias="DJ_CREATE_TABLES",
description="Default for Schema create_tables parameter. "
"Set to False for production mode to prevent automatic table creation.",
)


class ConnectionSettings(BaseSettings):
Expand Down
2 changes: 1 addition & 1 deletion src/datajoint/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# version bump auto managed by Github Actions:
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
# manually set this version will be eventually overwritten by the above actions
__version__ = "2.0.0a22"
__version__ = "2.0.0a24"
Loading