"""
A basic object relational mapper for SQLite.
This is a very simple ORM implementation which contains only functionality needed by
Maestral. Many operations will still require explicit SQL statements. This module is not
a replacement for fully featured ORMs such as SQLAlchemy, but may be useful when system
memory is constrained.
"""
from __future__ import annotations
from weakref import WeakValueDictionary
from typing import Any, Generator, TypeVar, Generic, Union, Optional, cast, overload
from .core import Database
from .query import Query
from .types import SqlType, SqlEnum
# Plain Python values that columns convert to for storage: str, int, float or None.
SQLSafeType = Optional[Union[str, int, float]]

# Python-side value type of a column.
T = TypeVar("T")
# SQL-side (stored) value type of a column.
ST = TypeVar("ST")
# Concrete model class handled by a Manager.
M = TypeVar("M", bound="Model")
# Public API of this module.
__all__ = ["Column", "NonNullColumn", "NoDefault", "Manager", "Model"]
class NoDefault:
    """
    Sentinel marking that a column has no default value.

    The class object itself is used as the marker. A dedicated sentinel is needed
    because ``None`` can be a legitimate column default.
    """
class Column(Generic[T, ST]):
    """
    Represents a column in a database table.

    :param sql_type: Column type in database table. Python types which don't have SQLite
        equivalents, such as :class:`enum.Enum`, will be converted appropriately.
    :param unique: If ``True``, sets a unique constraint on the column.
    :param primary_key: If ``True``, marks this column as a primary key column.
        Currently, only a single primary key column is supported.
    :param index: If ``True``, create an index on this column.
    :param default: Default value for the column. Set to :class:`NoDefault` if no
        default value should be used. Note that None / NULL is a valid default for an
        SQLite column.
    """

    def __init__(
        self,
        sql_type: SqlType[T, ST],
        unique: bool = False,
        primary_key: bool = False,
        index: bool = False,
        default: T | type[NoDefault] | None = None,
    ):
        self.type = sql_type
        self.unique = unique
        self.primary_key = primary_key
        self.index = index
        # Filled in by __set_name__ once the column is assigned to a class attribute.
        self.name = ""
        self.default: T | type[NoDefault] | None = default

    def __set_name__(self, owner: Any, name: str) -> None:
        self.name = name
        # Per-instance values are stored under this attribute name.
        self.private_name = "_" + name

    @overload
    def __get__(self, obj: None, objtype: type | None = None) -> Column[T, ST]:
        ...

    @overload
    def __get__(self, obj: Any, objtype: type | None = None) -> T | None:
        ...

    def __get__(
        self, obj: Any, objtype: type | None = None
    ) -> Column[T, ST] | T | None:
        # Class-level access returns the column itself.
        if obj is None:
            return self

        if self.default is NoDefault:
            # No fallback: an unset value raises AttributeError.
            res = getattr(obj, self.private_name)
        else:
            res = getattr(obj, self.private_name, self.default)

        return cast(Optional[T], res)

    def __set__(self, obj: Any, value: T) -> None:
        setattr(obj, self.private_name, value)

    def render_constraints(self) -> str:
        """Returns a string with constraints for the SQLite column definition."""
        constraints = []

        if isinstance(self.type, SqlEnum):
            # Mypy type narrowing does not work well with generics.
            # See https://github.com/python/mypy/issues/12060.
            values = ", ".join(
                repr(member.name) for member in self.type.enum_type  # type:ignore
            )
            constraints.append(f"CHECK( {self.name} IN ({values}) )")

        if self.unique:
            constraints.append("UNIQUE")

        return " ".join(constraints)

    def render_properties(self) -> str:
        """
        Returns a string with properties for the SQLite column definition.

        A default value is first converted with :meth:`py_to_sql`, the same conversion
        applied when saving rows. It is then rendered as an SQL literal, so that values
        such as enum members, booleans or strings containing quotes produce valid SQL.
        """
        properties = []

        if self.primary_key:
            properties.append("PRIMARY KEY")

        if self.default is None or self.default is NoDefault:
            properties.append("DEFAULT NULL")
        else:
            sql_default = self.py_to_sql(cast(T, self.default))
            properties.append(f"DEFAULT {self._render_sql_literal(sql_default)}")

        return " ".join(properties)

    @staticmethod
    def _render_sql_literal(value: Any) -> str:
        """
        Renders an already converted SQL value as an SQLite literal.

        :param value: Value as returned by :meth:`py_to_sql`.
        :returns: SQL literal text for the value.
        :raises TypeError: if the value cannot be represented as an SQLite literal.
        """
        if value is None:
            return "NULL"
        if isinstance(value, bool):
            # Store booleans as 0 / 1 rather than the Python names True / False.
            return str(int(value))
        if isinstance(value, (int, float)):
            return repr(value)
        if isinstance(value, str):
            # SQL string literals escape single quotes by doubling them.
            escaped = value.replace("'", "''")
            return f"'{escaped}'"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return f"X'{bytes(value).hex()}'"
        raise TypeError(f"Cannot render {value!r} as an SQLite literal")

    def render_column(self) -> str:
        """Returns a string with the full SQLite column definition."""
        return " ".join(
            [
                self.name,
                self.type.sql_type,
                self.render_constraints(),
                self.render_properties(),
            ]
        )

    def py_to_sql(self, value: T | None) -> ST | None:
        """
        Converts a Python value to a value which can be stored in the database column.

        :param value: Native Python value.
        :returns: Converted Python value to store in database. Will only return str,
            int, float or None.
        """
        if value is None:
            return value
        return self.type.py_to_sql(value)

    def sql_to_py(self, value: ST | None) -> T | None:
        """
        Converts a database column value to the original Python type.

        :param value: Value from database column. Only accepts str, int, float or None.
        :returns: Converted Python value.
        """
        if value is None:
            return value
        return self.type.sql_to_py(value)
class NonNullColumn(Column[T, ST]):
    """
    A :class:`Column` which must not hold NULL values.

    Converting ``None`` to or from its SQL representation raises :class:`ValueError`.
    The rendered column definition carries a ``NOT NULL`` constraint. Unlike
    :class:`Column`, the default value is :class:`NoDefault` unless given.
    """

    def __init__(
        self,
        sql_type: SqlType[T, ST],
        unique: bool = False,
        primary_key: bool = False,
        index: bool = False,
        default: T | type[NoDefault] = NoDefault,
    ):
        super().__init__(
            sql_type,
            unique=unique,
            primary_key=primary_key,
            index=index,
            default=default,
        )

    def __set__(self, obj: Any, value: T | None) -> None:
        # None is stored as-is here; it is only rejected when converted via py_to_sql.
        setattr(obj, self.private_name, value)

    @overload
    def __get__(self, obj: None, objtype: type | None = None) -> Column[T, ST]:
        ...

    @overload
    def __get__(self, obj: Any, objtype: type | None = None) -> T:
        ...

    def __get__(self, obj: Any, objtype: type | None = None) -> Column[T, ST] | T:
        return cast(T, super().__get__(obj, objtype))

    def py_to_sql(self, value: T | None) -> ST:
        """
        Converts a Python value for storage in the database column.

        :param value: Native Python value.
        :returns: Converted value.
        :raises ValueError: if ``value`` is None.
        """
        if value is not None:
            return self.type.py_to_sql(value)
        raise ValueError("This column does not allow NULL values")

    def sql_to_py(self, value: ST | None) -> T:
        """
        Converts a database column value back to its Python type.

        :param value: Value read from the database column.
        :returns: Converted Python value.
        :raises ValueError: if ``value`` is None.
        """
        if value is not None:
            return self.type.sql_to_py(value)
        raise ValueError("Unexpected value None / NULL")

    def render_constraints(self) -> str:
        """Returns the parent's column constraints followed by ``NOT NULL``."""
        return " ".join([super().render_constraints(), "NOT NULL"])
class Manager(Generic[M]):
    """
    A data mapper interface for a table model.

    Creates the table as defined in the model if it doesn't already exist. Keeps a cache
    of weak references to all retrieved and created rows to speed up queries. The cache
    should be cleared manually if changes were made to the table from outside this
    manager.

    :param db: Database to use.
    :param model: Model for database table.
    :raises ValueError: if the model declares no primary key column.
    """

    def __init__(self, db: Database, model: type[M]) -> None:
        self.db = db
        self.model = model
        self.table_name = model.__tablename__

        # Only a single primary key column is supported.
        pk_column = next((col for col in model.__columns__ if col.primary_key), None)
        if pk_column is None:
            raise ValueError(f"Model {model.__name__} has no primary key column")
        self.pk_column = pk_column

        # Maps SQL primary key values to model objects without keeping them alive.
        self._cache: WeakValueDictionary[SQLSafeType, M] = WeakValueDictionary()

        # Precompute often-used SQL query strings. The same column collection is
        # iterated when binding values, so placeholders and values line up.
        self._columns = model.__columns__
        column_names = [col.name for col in self._columns]
        column_names_str = ", ".join(column_names)
        column_refs = ", ".join(["?"] * len(self._columns))
        self._sql_insert_template = "INSERT INTO {} ({}) VALUES ({})".format(
            self.table_name, column_names_str, column_refs
        )

        set_expressions_str = ", ".join(f"{name} = ?" for name in column_names)
        self._sql_update_template = "UPDATE {} SET {} WHERE {} = ?".format(
            self.table_name,
            set_expressions_str,
            self.pk_column.name,
        )

        self.create_table_if_not_exists()

    def create_table_if_not_exists(self) -> None:
        """Creates the table and any column indices as defined by the model."""
        column_defs = [col.render_column() for col in self.model.__columns__]
        column_defs_str = ", ".join(column_defs)
        sql = f"CREATE TABLE IF NOT EXISTS {self.table_name} ({column_defs_str});"
        self.db.executescript(sql)

        for column in self.model.__columns__:
            if column.index:
                # Derive the index name from the table name without surrounding quotes.
                table_name_stripped = self.table_name.strip("'\"")
                idx_name = f"idx_{table_name_stripped}_{column.name}"
                sql = (
                    f"CREATE INDEX IF NOT EXISTS {idx_name} "
                    f"ON {self.table_name} ({column.name});"
                )
                self.db.executescript(sql)

        self._did_create_table = True

    def clear_cache(self) -> None:
        """Clears our cache."""
        self._cache.clear()

    def delete(self, query: Query) -> None:
        """
        Deletes all rows matching a query and clears the cache.

        :param query: Query selecting the rows to delete.
        """
        clause, args = query.clause()
        sql = f"DELETE FROM {self.table_name} WHERE {clause}"
        self.db.execute(sql, *args)
        self.clear_cache()

    def select(self, query: Query) -> list[M]:
        """
        Returns all rows matching a query as model objects.

        :param query: Query selecting the rows.
        :returns: List of model objects.
        """
        clause, args = query.clause()
        sql = f"SELECT * FROM {self.table_name} WHERE {clause}"
        result = self.db.execute(sql, *args)
        return [self._item_from_kwargs(**row) for row in result.fetchall()]

    def select_iter(
        self, query: Query, size: int = 1000
    ) -> Generator[list[M], Any, None]:
        """
        Yields rows matching a query as lists of model objects, fetched in batches.

        :param query: Query selecting the rows.
        :param size: Maximum number of rows per yielded list.
        :returns: Generator of lists of model objects.
        """
        clause, args = query.clause()
        sql = f"SELECT * FROM {self.table_name} WHERE {clause}"
        result = self.db.execute(sql, *args)
        rows = result.fetchmany(size)
        while len(rows) > 0:
            yield [self._item_from_kwargs(**row) for row in rows]
            rows = result.fetchmany(size)

    def select_sql(self, sql: str, *args: Any) -> list[M]:
        """
        Performs the given SQL query and converts any returned rows to model objects.

        :param sql: SQL statement to execute. It is appended after
            ``SELECT * FROM <table>``.
        :param args: Parameters to substitute for placeholders in SQL statement.
        :returns: List of model objects from the query.
        """
        result = self.db.execute(f"SELECT * FROM {self.table_name} {sql}", *args)
        return [self._item_from_kwargs(**row) for row in result.fetchall()]

    def delete_primary_key(self, primary_key: Any) -> None:
        """
        Delete a model object / row from database by primary key.

        :param primary_key: Primary key for row.
        """
        pk_sql = self.pk_column.py_to_sql(primary_key)
        sql = f"DELETE from {self.table_name} WHERE {self.pk_column.name} = ?"
        self.db.execute(sql, pk_sql)
        self._cache.pop(pk_sql, None)

    def get(self, primary_key: Any) -> M | None:
        """
        Gets a model object from database by its primary key. This will return a cached
        value if available and None if no row with the primary key exists.

        :param primary_key: Primary key for row.
        :returns: Model object representing the row.
        """
        pk_sql = self.pk_column.py_to_sql(primary_key)

        try:
            return self._cache[pk_sql]
        except KeyError:
            pass

        sql = f"SELECT * FROM {self.table_name} WHERE {self.pk_column.name} = ?"
        result = self.db.execute(sql, pk_sql)
        row = result.fetchone()

        if not row:
            return None

        return self._item_from_kwargs(**row)

    def has(self, primary_key: Any) -> bool:
        """
        Checks if a model object exists in database by its primary key.

        :param primary_key: The primary key, as a native Python value.
        :returns: Whether the corresponding row exists in the table.
        """
        return self._has_sql_key(self.pk_column.py_to_sql(primary_key))

    def save(self, obj: M) -> M:
        """
        Saves a model object to the database table. If the primary key is None, a new
        primary key will be generated by SQLite on inserting the row. This key will be
        retrieved and stored in the primary key property of the object.

        :param obj: Model object to save.
        :returns: Saved model object.
        :raises ValueError: if a row with the object's primary key already exists.
        """
        pk_sql = self._get_primary_key(obj)

        # pk_sql is already converted: do not pass it through has(), which would
        # convert it a second time. A NULL key never matches an existing row.
        if pk_sql is not None and self._has_sql_key(pk_sql):
            raise ValueError(f"Object with primary key {pk_sql} is already registered")

        sql_values = (col.py_to_sql(getattr(obj, col.name)) for col in self._columns)
        self.db.execute(self._sql_insert_template, *sql_values)

        if pk_sql is None:
            # Round trip to fetch created primary key.
            res = self.db.execute("SELECT last_insert_rowid()").fetchone()
            pk_sql = res["last_insert_rowid()"]
            pk_py = self.pk_column.sql_to_py(pk_sql)
            setattr(obj, self.pk_column.name, pk_py)

        self._cache[pk_sql] = obj

        return obj

    def update(self, obj: M) -> None:
        """
        Updates the database table from a model object. Inserts the object if no row
        with its primary key exists yet.

        :param obj: The object to update.
        :raises ValueError: if the object's primary key is None.
        """
        pk_sql = self._get_primary_key(obj)

        if pk_sql is None:
            raise ValueError("Primary key is required to update row")

        if not self._has_sql_key(pk_sql):
            self.save(obj)
            return

        sql_values = [col.py_to_sql(getattr(obj, col.name)) for col in self._columns]
        self.db.execute(self._sql_update_template, *sql_values, pk_sql)
        # Ensure later lookups return this object instead of a stale cached instance
        # with the same primary key.
        self._cache[pk_sql] = obj

    def count(self) -> int:
        """Returns the number of rows in the table."""
        res = self.db.execute(f"SELECT COUNT(*) FROM {self.table_name};")
        counts = res.fetchone()
        return cast(int, counts[0])

    def clear(self) -> None:
        """Delete all rows from table by dropping and re-creating it."""
        self.db.execute(f"DROP TABLE {self.table_name}")
        self.clear_cache()
        self.create_table_if_not_exists()

    def _has_sql_key(self, pk_sql: SQLSafeType) -> bool:
        """
        Checks if a row exists for an already converted primary key value.

        :param pk_sql: Primary key as returned by the primary key column's
            ``py_to_sql``.
        :returns: Whether the corresponding row exists in the table.
        """
        pk_name = self.pk_column.name
        sql = f"SELECT {pk_name} FROM {self.table_name} WHERE {pk_name} = ?"
        result = self.db.execute(sql, pk_sql)
        return bool(result.fetchone())

    def _get_primary_key(self, obj: M) -> SQLSafeType:
        """
        Returns the primary key value for a model object / row in the table.

        :param obj: Model instance which represents the row.
        :returns: Primary key for row, converted to its SQL representation.
        """
        pk_py = getattr(obj, self.pk_column.name)
        return self.pk_column.py_to_sql(pk_py)

    def _item_from_kwargs(self, **kwargs: Any) -> M:
        """
        Create a model object from SQL column values and register it in the cache.

        :param kwargs: Column values as read from the database.
        :returns: Model object.
        """
        # Convert any types as appropriate.
        values = {
            key: getattr(self.model, key).sql_to_py(value)
            for key, value in kwargs.items()
        }
        obj = self.model(**values)
        pk_sql = self._get_primary_key(obj)
        self._cache[pk_sql] = obj
        return obj
class ModelBase(type):
    """
    Metaclass for :class:`Model` which collects the :class:`Column` attributes declared
    directly in a class body.

    The collected columns are exposed as a ``__columns__`` frozenset. A private slot
    ``_<name>`` is declared for every column to hold instance values.
    """

    def __new__(
        mcs, cls_name: str, bases: tuple[type], namespace: dict[str, Any], **kwargs: Any
    ) -> ModelBase:
        declared = {
            attr: value
            for attr, value in namespace.items()
            if isinstance(value, Column)
        }

        namespace["__columns__"] = frozenset(declared.values())

        # Only declare __slots__ when the class body defines columns. Setting them on a
        # class without columns would prevent subclasses from having weakrefs.
        if declared:
            namespace["__slots__"] = [f"_{attr}" for attr in declared]

        return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
class Model(metaclass=ModelBase):
    """
    Abstract object model to represent an SQL table.

    Instances of this class are model objects which correspond to rows in the database
    table.

    To define a table, subclass :class:`Model` and define class properties as
    :class:`Column`. Override the ``__tablename__`` attribute with the SQLite table name
    to use. The ``__columns__`` attribute will be populated automatically for you.
    """

    __tablename__: str
    """The name of the database table"""

    __columns__: frozenset[Column[Any, Any]]
    """The columns of the database table"""

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialise with keyword arguments corresponding to column names and values.

        :param kwargs: Keyword arguments assigning values to table columns. Values may
            be omitted for nullable columns and for non-nullable columns which declare
            a (non-None) default value.
        :raises TypeError: if a keyword does not correspond to a column of the model,
            or if no value is given for a non-nullable column without a default.
        """
        column_names = {col.name for col in self.__columns__}

        # Non-nullable columns must be given explicitly unless they declare a default,
        # which Column.__get__ falls back to when no value was set.
        missing_columns = {
            col.name
            for col in self.__columns__
            if isinstance(col, NonNullColumn)
            and (col.default is NoDefault or col.default is None)
        }

        for name, value in kwargs.items():
            if name not in column_names:
                raise TypeError(f"{self.__class__.__name__} has no column '{name}'")
            setattr(self, name, value)
            missing_columns.discard(name)

        if missing_columns:
            raise TypeError(f"Column values required for {missing_columns}")

    def __repr__(self) -> str:
        attributes = ", ".join(
            f"{col.name}={getattr(self, col.name)}" for col in self.__columns__
        )
        return f"<{self.__class__.__name__}({attributes})>"