import os
from peewee import SqliteDatabase
from bw2data import projects
from bw2data.search.schema import BW2Schema
# Manager for the SQLite-backed full-text search index of a project.
class IndexManager:
    """Manage the SQLite search index stored in the project's ``search`` directory.

    Every operation opens its own connection context and binds ``MODELS`` to
    this manager's database for the duration of the call.

    NOTE(review): ``MODELS`` and ``_format_dataset`` are used below but are not
    defined in the visible part of this file -- confirm they are provided
    elsewhere in the module / class.
    """

    # Number of rows sent per ``insert_many`` statement.
    # NOTE(review): presumably chosen to stay under SQLite's bound-variable
    # limit -- confirm before changing.
    _CHUNK_SIZE = 100

    def __init__(self, database_path):
        """Open the index at ``database_path`` and create its tables if missing.

        Args:
            database_path: File name joined onto the project's ``search``
                directory to obtain the SQLite file path.
        """
        self.path = os.path.join(projects.request_directory("search"), database_path)
        self.db = SqliteDatabase(self.path)
        with self.db.connection_context():
            if not os.path.exists(self.path) or not self.db.get_tables():
                self.create()

    def get(self):
        """Return this manager itself."""
        return self

    def create(self):
        """Drop the index tables, then create them again (empty)."""
        self.delete_database()
        with self.db.connection_context(), self.db.bind_ctx(MODELS):
            self.db.create_tables(MODELS)

    def add_dataset(self, ds):
        """Add a single dataset to the index."""
        self.add_datasets([ds])

    def add_datasets(self, datasets):
        """Add datasets to the index in chunks of ``_CHUNK_SIZE`` rows.

        Args:
            datasets: Any iterable of datasets, including one-shot iterators
                and generators; it is materialised once before inserting.
        """
        with self.db.connection_context(), self.db.bind_ctx(MODELS):
            for chunk in self._chunked(datasets, self._CHUNK_SIZE):
                # Format each chunk once and reuse the rows for every model.
                rows = [self._format_dataset(ds) for ds in chunk]
                for model in MODELS:
                    model.insert_many(rows).execute()

    def update_dataset(self, ds):
        """Replace the indexed rows of ``ds`` (matched on code and database)."""
        with self.db.connection_context(), self.db.bind_ctx(MODELS):
            for model in MODELS:
                self._delete_rows(model, ds)
                model.insert(**self._format_dataset(ds)).execute()

    def delete_dataset(self, ds):
        """Remove the indexed rows of ``ds`` (matched on code and database)."""
        with self.db.connection_context(), self.db.bind_ctx(MODELS):
            for model in MODELS:
                self._delete_rows(model, ds)

    def delete_database(self):
        """Drop every table in ``MODELS`` from the index database."""
        with self.db.connection_context(), self.db.bind_ctx(MODELS):
            self.db.drop_tables(MODELS)

    def close(self):
        """Close the underlying database connection."""
        self.db.close()

    @staticmethod
    def escape_search_for_fts5(string: str) -> str:
        """Quote every word of ``string`` so FTS5 treats special characters literally.

        Uses SQL escape syntax, see https://stackoverflow.com/a/43756146

        Double quotes are removed from the input, the input is split on
        whitespace, and each word is wrapped in ``""``. A trailing ``*``
        wildcard is kept, but outside the quoted word. Case is left unchanged.

        ``Com* cheese cow's`` is converted to ``"Com"* "cheese" "cow's"``.
        """
        return " ".join(
            f'"{term[:-1]}"*' if term.endswith("*") else f'"{term}"'
            for term in string.replace('"', "").split()
        )

    def search(self, string, limit=None, weights=None, mask=None, filter=None):
        """Search the index and return the matching rows as dictionaries.

        Args:
            string: Search text; the literal ``"*"`` selects every row without
                running a full-text match.
            limit: Maximum number of rows returned, or ``None`` for no limit.
            weights: Passed through to ``BW2Schema.search_bm25``.
            mask: Mapping of field -> text; rows where ANY field contains its
                text are dropped.
            filter: Mapping of field -> text; only rows where EVERY field
                contains its text are kept.

        Returns:
            A list of dicts with the keys ``name``, ``comment``, ``product``,
            ``categories``, ``synonyms``, ``location``, ``database``, ``code``.
        """
        with self.db.connection_context(), self.db.bind_ctx(MODELS):
            if string == "*":
                query = BW2Schema
            else:
                query = BW2Schema.search_bm25(
                    self.escape_search_for_fts5(string),
                    weights=weights,
                )
            # Skip the SQL-level limit when post-filtering is needed so we
            # don't prematurely discard rows that would survive filter/mask.
            sql_limit = None if (filter or mask) else limit
            results = list(
                query.select(
                    BW2Schema.name,
                    BW2Schema.comment,
                    BW2Schema.product,
                    BW2Schema.categories,
                    BW2Schema.synonyms,
                    BW2Schema.location,
                    BW2Schema.database,
                    BW2Schema.code,
                )
                .limit(sql_limit)
                .dicts()
                .execute()
            )

        # The rows are fully materialised, so post-processing needs no connection.
        if filter:
            results = self._filter_rows(results, filter)
        if mask:
            results = self._mask_rows(results, mask)
        if limit is not None:
            results = results[:limit]
        return results

    @staticmethod
    def _chunked(items, size):
        """Yield successive lists of at most ``size`` elements from ``items``."""
        materialised = list(items)
        for start in range(0, len(materialised), size):
            yield materialised[start : start + size]

    @staticmethod
    def _delete_rows(model, ds):
        """Delete the rows of ``model`` whose code and database match ``ds``."""
        model.delete().where(
            model.code == ds["code"], model.database == ds["database"]
        ).execute()

    @staticmethod
    def _normalize_criteria(criteria):
        """Turn filter/mask values into the text that is searched for."""
        return {
            key: value.lower() if isinstance(value, str) else str(value)
            for key, value in criteria.items()
        }

    @staticmethod
    def _field_text(row, key):
        """Return ``row[key]`` as text; missing or ``None`` values become ``""``."""
        value = row.get(key)
        return "" if value is None else str(value)

    @classmethod
    def _filter_rows(cls, rows, criteria):
        """Keep only the rows in which every criterion occurs as a substring.

        Criteria strings are lowercased; row values are compared as stored.
        NOTE(review): presumably indexed text is already lowercase -- confirm.
        """
        needles = cls._normalize_criteria(criteria)
        return [
            row
            for row in rows
            if all(
                needle in cls._field_text(row, key)
                for key, needle in needles.items()
            )
        ]

    @classmethod
    def _mask_rows(cls, rows, criteria):
        """Drop the rows in which any criterion occurs as a substring."""
        needles = cls._normalize_criteria(criteria)
        return [
            row
            for row in rows
            if not any(
                needle in cls._field_text(row, key)
                for key, needle in needles.items()
            )
        ]