Source code for bw2data.search.indices

import os

from peewee import SqliteDatabase

from bw2data import projects
from bw2data.search.schema import BW2Schema

# Peewee model classes that make up the search index. IndexManager binds
# these to its own SQLite database and creates, drops, inserts into and
# deletes from each of them.
MODELS = (
    BW2Schema,
)
class IndexManager:
    """Full-text search index for datasets, backed by a SQLite file.

    The SQLite file lives in the current project's ``search`` directory.
    Every operation opens a short-lived connection (``connection_context``)
    and binds the models in ``MODELS`` to this instance's database
    (``bind_ctx``) for the duration of that operation.
    """

    def __init__(self, database_path):
        """Open the index, creating its tables if they do not exist yet.

        Args:
            database_path: File name of the SQLite database, relative to the
                project's ``search`` directory.
        """
        self.path = os.path.join(projects.request_directory("search"), database_path)
        # Check for the file *before* connecting: opening a SQLite connection
        # creates the file, so checking afterwards would always find it.
        is_new = not os.path.exists(self.path)
        self.db = SqliteDatabase(self.path)
        with self.db.connection_context():
            if is_new or len(self.db.get_tables()) == 0:
                self.create()

    def get(self):
        """Return this manager itself."""
        return self

    def create(self):
        """(Re)create the index tables, dropping any existing ones first."""
        self.delete_database()
        with self.db.connection_context():
            with self.db.bind_ctx(MODELS):
                self.db.create_tables(MODELS)

    def _format_dataset(self, ds):
        """Convert a dataset mapping into a row dict for the index tables.

        Text fields are lower-cased; missing or ``None`` values become ``""``.
        ``categories`` and ``synonyms`` are joined with ``", "``.

        Args:
            ds: Mapping describing one dataset. It must contain ``"database"``
                and ``"code"``; other keys are optional.

        Returns:
            dict with keys ``name``, ``comment``, ``product``, ``categories``,
            ``synonyms``, ``location`` (all lower-cased strings), plus
            ``database`` and ``code`` copied verbatim.

        Raises:
            KeyError: if ``"database"`` or ``"code"`` is missing from ``ds``.
        """

        def _fix_location(string):
            # Tuple locations keep only their second element.
            if isinstance(string, tuple):
                string = string[1]
            if isinstance(string, str):
                # The literal text "none" (any case) means "no location".
                if string.lower() == "none":
                    return ""
                return string.lower().strip()
            # Any other type is treated as "no location".
            return ""

        return dict(
            name=(ds.get("name") or "").lower(),
            comment=(ds.get("comment") or "").lower(),
            product=(ds.get("reference product") or "").lower(),
            categories=", ".join(ds.get("categories") or []).lower(),
            synonyms=", ".join(ds.get("synonyms") or []).lower(),
            location=_fix_location(ds.get("location") or ""),
            database=ds["database"],
            code=ds["code"],
        )

    def add_dataset(self, ds):
        """Insert a single dataset into the index."""
        self.add_datasets([ds])

    def add_datasets(self, datasets, chunk_size=100):
        """Insert many datasets into the index, in chunks.

        Args:
            datasets: Any iterable of dataset mappings; generators are
                accepted because the input is materialised first.
            chunk_size: Number of rows per ``INSERT`` statement. Each row
                binds 8 parameters, so keep ``chunk_size * 8`` below SQLite's
                bound-parameter limit (999 on SQLite builds before 3.32).

        Raises:
            ValueError: if ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        all_datasets = list(datasets)
        with self.db.connection_context():
            with self.db.bind_ctx(MODELS):
                # Iterate over the materialised list, not ``datasets``, which
                # may be an iterator without ``len()``.
                for start in range(0, len(all_datasets), chunk_size):
                    # Format each chunk once and reuse it for every model.
                    rows = [
                        self._format_dataset(ds)
                        for ds in all_datasets[start : start + chunk_size]
                    ]
                    for model in MODELS:
                        model.insert_many(rows).execute()

    def update_dataset(self, ds):
        """Replace the indexed row for ``ds`` (matched on ``code`` and ``database``)."""
        with self.db.connection_context():
            with self.db.bind_ctx(MODELS):
                # Delete and re-insert in one transaction so that a failed
                # insert does not leave the dataset missing from the index.
                with self.db.atomic():
                    for model in MODELS:
                        model.delete().where(
                            model.code == ds["code"], model.database == ds["database"]
                        ).execute()
                        model.insert(**self._format_dataset(ds)).execute()

    def delete_dataset(self, ds):
        """Remove the rows for ``ds`` (matched on ``code`` and ``database``)."""
        with self.db.connection_context():
            with self.db.bind_ctx(MODELS):
                for model in MODELS:
                    model.delete().where(
                        model.code == ds["code"], model.database == ds["database"]
                    ).execute()

    def delete_database(self):
        """Drop every table listed in ``MODELS``."""
        with self.db.connection_context():
            with self.db.bind_ctx(MODELS):
                self.db.drop_tables(MODELS)

    def close(self):
        """Close the database connection."""
        self.db.close()

    @staticmethod
    def escape_search_for_fts5(string: str) -> str:
        """Quote each whitespace-separated term so FTS5 treats it literally.

        Uses SQL escape syntax to neutralise special characters (see
        https://stackoverflow.com/a/43756146). Every word is wrapped in
        double quotes; a trailing ``*`` wildcard is kept, but placed outside
        the quotes. Double quotes in the input are removed. Case is preserved.

        Example: ``Com* cheese cow's`` -> ``"Com"* "cheese" "cow's"``.
        """
        # NOTE(review): a bare "*" term becomes ``""*``; confirm FTS5 accepts it.
        return " ".join(
            [
                f'"{term[:-1]}"*' if term.endswith("*") else f'"{term}"'
                for term in string.replace('"', "").split()
            ]
        )

    @staticmethod
    def _normalize_criteria(criteria):
        """Lower-case string values and stringify every value of a filter/mask dict."""
        return {
            key: str(value.lower() if isinstance(value, str) else value)
            for key, value in criteria.items()
        }

    def search(self, string, limit=None, weights=None, mask=None, filter=None):
        """Search the index and return matching rows as dicts.

        Args:
            string: Search text. ``"*"`` selects every row; anything else is
                escaped with ``escape_search_for_fts5`` and passed to
                ``BW2Schema.search_bm25``.
            limit: Maximum number of results, or ``None`` for no limit.
            weights: Passed unchanged to ``search_bm25`` (presumably per-column
                weights; confirm against ``BW2Schema``).
            mask: Optional ``{column: value}``. Rows where ANY value is a
                substring of its column are dropped.
            filter: Optional ``{column: value}``. Only rows where EVERY value
                is a substring of its column are kept.

            String values in ``mask`` and ``filter`` are lower-cased; other
            values are compared via ``str()``. Unknown columns compare against
            ``""``.

        Returns:
            list of dicts with keys ``name``, ``comment``, ``product``,
            ``categories``, ``synonyms``, ``location``, ``database``, ``code``.
        """
        with self.db.connection_context():
            with self.db.bind_ctx(MODELS):
                if string == "*":
                    query = BW2Schema
                else:
                    query = BW2Schema.search_bm25(
                        self.escape_search_for_fts5(string),
                        weights=weights,
                    )
                # Skip the SQL-level limit when post-filtering is needed, so we
                # don't discard results that would survive the filter/mask step.
                sql_limit = None if (filter or mask) else limit
                results = list(
                    query.select(
                        BW2Schema.name,
                        BW2Schema.comment,
                        BW2Schema.product,
                        BW2Schema.categories,
                        BW2Schema.synonyms,
                        BW2Schema.location,
                        BW2Schema.database,
                        BW2Schema.code,
                    )
                    .limit(sql_limit)
                    .dicts()
                    .execute()
                )
        if filter:
            wanted = self._normalize_criteria(filter)
            results = [
                row
                for row in results
                if all(value in row.get(key, "") for key, value in wanted.items())
            ]
        if mask:
            unwanted = self._normalize_criteria(mask)
            results = [
                row
                for row in results
                if not any(value in row.get(key, "") for key, value in unwanted.items())
            ]
        if limit is not None:
            results = results[:limit]
        return results