Source code for bw2data.backends.peewee.database

# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals
from eight import *

from . import sqlite3_lci_db
from ... import mapping, geomapping, config, databases, preferences
from ...errors import UntypedExchange, InvalidExchange, UnknownObject, WrongDatabase
from ...project import writable_project
from ...search import IndexManager, Searcher
from ...utils import MAX_INT_32, TYPE_DICTIONARY
from ..base import LCIBackend
from .proxies import Activity
from .schema import ActivityDataset, ExchangeDataset
from .utils import dict_as_activitydataset, dict_as_exchangedataset
from peewee import fn, DoesNotExist
import itertools
import datetime
import numpy as np
import pprint
import pyprind
import random
import sqlite3
import warnings
try:
    import cPickle as pickle
except ImportError:
    import pickle


# AD = ActivityDataset
# out_a = AD.alias()
# in_a = AD.alias()
# qs = AD.select().where(AD.key == (out_a.select(in_a.key).join(ED, on=(ED.output == out_a.key)).join(in_a, on=(ED.input_ == in_a.key)).where((ED.type_ == 'technosphere') & (in_a.product == out_a.product)))

[docs] _VALID_KEYS = {'location', 'name', 'product', 'type'}
class SQLiteBackend(LCIBackend):
    """LCI database backend storing activities and exchanges in SQLite.

    Activities are rows of ``ActivityDataset`` and exchanges are rows of
    ``ExchangeDataset``; both are accessed through Peewee, except in
    :meth:`process`, which reads exchanges with a raw ``sqlite3`` cursor.

    Iteration, ``len()`` and ``in`` all go through :meth:`_get_queryset`
    and therefore respect the ``filters`` and ``order_by`` properties.
    """

    backend = "sqlite"

    def __init__(self, *args, **kwargs):
        super(SQLiteBackend, self).__init__(*args, **kwargs)
        # Database-wide filters ({column name: value}) applied to queries.
        self._filters = {}
        # Name of the ActivityDataset column used for ordering, or None.
        self._order_by = None

    ### Iteration, filtering, and ordering
    ######################################

    # Private methods

    def __iter__(self):
        """Yield an ``Activity`` proxy for each dataset in the queryset."""
        for ds in self._get_queryset():
            yield Activity(ds)

    def __len__(self):
        """Return the number of datasets in the (filtered) queryset."""
        return self._get_queryset().count()

    def __contains__(self, obj):
        """Return whether a dataset with code ``obj[1]`` is in this database.

        Only the second element of ``obj`` (the code) is compared; the
        query is already restricted to this database's name.
        """
        return self._get_queryset(filters={'code': obj[1]}).count() > 0

    @property
    def _searchable(self):
        """Whether the ``databases`` metadata marks this database searchable."""
        return databases.get(self.name, {}).get('searchable', False)

    def _get_queryset(self, random=False, filters=True):
        """Build the Peewee query selecting this database's activities.

        Args:
            random: If true, ignore ``order_by`` and order randomly.
            filters: If falsy, apply no filters at all. If a dict, apply
                its ``{column: value}`` pairs *in addition to* the
                database-wide ``self.filters``. Any other truthy value
                applies only ``self.filters``.

        Returns:
            A Peewee select query. Results are ordered by ``order_by``
            when it is set and ``random`` is false, and by SQLite's
            random function otherwise.
        """
        qs = ActivityDataset.select().where(
            ActivityDataset.database == self.name)
        if filters:
            if isinstance(filters, dict):
                for key, value in filters.items():
                    qs = qs.where(getattr(ActivityDataset, key) == value)
            if self.filters:
                print("Using the following database filters:")
                pprint.pprint(self.filters)
                for key, value in self.filters.items():
                    qs = qs.where(getattr(ActivityDataset, key) == value)
        if self.order_by and not random:
            qs = qs.order_by(getattr(ActivityDataset, self.order_by))
        else:
            qs = qs.order_by(fn.Random())
        return qs

    def _get_filters(self):
        """Return the current database-wide filter dictionary."""
        return self._filters

    def _set_filters(self, filters):
        """Set (or, with a falsy value, clear) the database-wide filters.

        ``filters`` must be a dictionary whose keys are all in
        ``_VALID_KEYS``; an ``AssertionError`` is raised otherwise. All
        keys are validated before the new filters are stored, so a bad
        key leaves the previous filters in place.
        """
        if not filters:
            self._filters = {}
        else:
            print("Filters will effect all database queries"
                  " until unset (`.filters = None`)")
            assert isinstance(filters, dict), "Filter must be a dictionary"
            for key in filters:
                assert key in _VALID_KEYS, \
                    "Filter key {} is invalid".format(key)
            self._filters = filters
        return self

    def _get_order_by(self):
        """Return the name of the ordering column, or ``None``."""
        return self._order_by

    def _set_order_by(self, field):
        """Set (or, with a falsy value, clear) the ordering column.

        ``field`` must be one of ``_VALID_KEYS``; an ``AssertionError``
        is raised otherwise.
        """
        if not field:
            self._order_by = None
        else:
            assert field in _VALID_KEYS, \
                "order_by field {} is invalid".format(field)
            self._order_by = field
        return self

    # Public API

    filters = property(_get_filters, _set_filters)
    order_by = property(_get_order_by, _set_order_by)

    def random(self, filters=True, true_random=False):
        """Return a random ``Activity``, or ``None`` if nothing matches.

        True random requires loading and sorting data in SQLite, and can be
        resource-intensive.

        Args:
            filters: Passed through to :meth:`_get_queryset`.
            true_random: If true, force random ordering in SQLite and take
                the first row. Otherwise pick a random offset into the
                queryset.

        A ``UserWarning`` ("This database is empty") is issued when no
        dataset is found.
        """
        try:
            if true_random:
                return Activity(
                    self._get_queryset(random=True, filters=filters).get()
                )
            qs = self._get_queryset(filters=filters)
            # Count the same queryset we sample from, so a ``filters``
            # dict is taken into account.
            count = qs.count()
            if not count:
                raise DoesNotExist
            # ``randint`` includes both endpoints, so the largest valid
            # offset is ``count - 1``; ``count`` would skip every row.
            return Activity(qs.offset(random.randint(0, count - 1)).get())
        except DoesNotExist:
            warnings.warn("This database is empty")
            return None

    def get(self, code):
        """Return the ``Activity`` with the given ``code``.

        Database-wide filters are not applied. The underlying Peewee
        ``.get()`` raises ``DoesNotExist`` when there is no such row.
        """
        return Activity(
            self._get_queryset(filters=False).where(
                ActivityDataset.code == code).get()
        )

    ### Data management
    ###################

    # Private methods

    def _drop_indices(self):
        """Drop the activity and exchange indices, if they exist."""
        with sqlite3_lci_db.transaction():
            sqlite3_lci_db.execute_sql('DROP INDEX IF EXISTS "activitydataset_key"')
            sqlite3_lci_db.execute_sql('DROP INDEX IF EXISTS "exchangedataset_input"')
            sqlite3_lci_db.execute_sql('DROP INDEX IF EXISTS "exchangedataset_output"')

    def _add_indices(self):
        """(Re)create the activity and exchange indices, if missing."""
        with sqlite3_lci_db.transaction():
            sqlite3_lci_db.execute_sql('CREATE UNIQUE INDEX IF NOT EXISTS "activitydataset_key" ON "activitydataset" ("database", "code")')
            sqlite3_lci_db.execute_sql('CREATE INDEX IF NOT EXISTS "exchangedataset_input" ON "exchangedataset" ("input_database", "input_code")')
            sqlite3_lci_db.execute_sql('CREATE INDEX IF NOT EXISTS "exchangedataset_output" ON "exchangedataset" ("output_database", "output_code")')

    def _efficient_write_dataset(self, index, key, ds, exchanges, activities):
        """Queue one dataset and its exchanges for batched insertion.

        Args:
            index: Position of the dataset in the input (unused here).
            key: ``(database, code)`` tuple of the dataset.
            ds: Dataset dictionary; its exchanges get ``output`` set to
                ``key`` in place.
            exchanges: Pending exchange rows not yet inserted.
            activities: Pending activity rows not yet inserted.

        Returns:
            The ``(exchanges, activities)`` pending lists. A list that
            was flushed to the database is replaced by a new empty list,
            so callers must use the returned values.

        Raises:
            InvalidExchange: An exchange lacks ``input`` or ``amount``.
            UntypedExchange: An exchange lacks ``type``.
        """
        for exchange in ds.get('exchanges', []):
            if 'input' not in exchange or 'amount' not in exchange:
                raise InvalidExchange
            if 'type' not in exchange:
                raise UntypedExchange
            exchange['output'] = key
            exchanges.append(dict_as_exchangedataset(exchange))

            # Query gets passed as INSERT INTO x VALUES ('?', '?'...)
            # SQLite3 has a limit of 999 variables,
            # So 6 fields * 125 is under the limit
            # Otherwise get the following:
            #  peewee.OperationalError: too many SQL variables
            if len(exchanges) > 125:
                ExchangeDataset.insert_many(exchanges).execute()
                exchanges = []

        # Store the activity without its exchanges; they live in their
        # own table.
        ds = {k: v for k, v in ds.items() if k != "exchanges"}
        ds["database"] = key[0]
        ds["code"] = key[1]

        activities.append(dict_as_activitydataset(ds))

        if len(activities) > 125:
            ActivityDataset.insert_many(activities).execute()
            activities = []

        if not getattr(config, "is_test", None):
            self.pbar.update()

        return exchanges, activities

    def _efficient_write_many_data(self, data, indices=True):
        """Replace this database's contents with ``data`` in one transaction.

        Existing rows are deleted (parameters are kept), then all
        datasets are inserted in batches. For 100 or more datasets, and
        when ``indices`` is true, the table indices are dropped first
        and recreated afterwards. On any error the transaction is rolled
        back and the exception re-raised.
        """
        be_complicated = len(data) >= 100 and indices
        if be_complicated:
            self._drop_indices()
        sqlite3_lci_db.db.autocommit = False
        try:
            sqlite3_lci_db.db.begin()
            self.delete(keep_params=True, warn=False)
            exchanges, activities = [], []

            if not getattr(config, "is_test", None):
                self.pbar = pyprind.ProgBar(
                    len(data),
                    title="Writing activities to SQLite3 database:",
                    monitor=True
                )

            for index, (key, ds) in enumerate(data.items()):
                exchanges, activities = self._efficient_write_dataset(
                    index, key, ds, exchanges, activities
                )

            if not getattr(config, "is_test", None):
                print(self.pbar)
                del self.pbar

            # Flush whatever is left in the final, partial batches.
            if activities:
                ActivityDataset.insert_many(activities).execute()
            if exchanges:
                ExchangeDataset.insert_many(exchanges).execute()
            sqlite3_lci_db.db.commit()
            if len(self) > 500:
                sqlite3_lci_db.vacuum()
        except:
            # Deliberately broad: roll back on *any* interruption
            # (including KeyboardInterrupt), then propagate it.
            sqlite3_lci_db.db.rollback()
            raise
        finally:
            sqlite3_lci_db.db.autocommit = True
            if be_complicated:
                self._add_indices()

    # Public API

    @writable_project
    def write(self, data, process=True):
        """Write ``data`` to database.

        ``data`` must be a dictionary of the form::

            {
                ('database name', 'dataset code'): {dataset}
            }

        Writing a database will first delete all existing data.

        The database is registered if needed, ``mapping`` and
        ``geomapping`` are updated, the search index is rebuilt, and,
        if ``process`` is true, :meth:`process` is called.

        Raises:
            WrongDatabase: A key in ``data`` names a different database.
        """
        if self.name not in databases:
            self.register()
        wrong_database = {key[0] for key in data}.difference({self.name})
        if wrong_database:
            raise WrongDatabase("Can't write activities in databases {} to database {}".format(
                                wrong_database, self.name))

        databases[self.name]['number'] = len(data)
        databases.set_modified(self.name)

        mapping.add(data.keys())

        if preferences.get('allow incomplete imports'):
            # Also map exchange endpoints that are not among the
            # datasets being written.
            mapping.add({exc['input'] for ds in data.values() for exc in ds.get('exchanges', [])})
            mapping.add({exc.get('output') for ds in data.values() for exc in ds.get('exchanges', []) if exc.get('output')})

        geomapping.add({x["location"] for x in data.values()
                        if x.get("location")})
        if data:
            try:
                self._efficient_write_many_data(data)
            except:
                # Purge all data from database, then reraise
                self.delete(warn=False)
                raise

        self.make_searchable(reset=True)

        if process:
            self.process()

    def load(self, *args, **kwargs):
        """Return all datasets as ``{(database, code): dataset}``.

        Each dataset gets an ``exchanges`` list built from the exchange
        table. Should not be used, in general; relatively slow.
        """
        activities = [obj['data'] for obj in self._get_queryset().dicts()]
        activities = {(o['database'], o['code']): o for o in activities}
        for o in activities.values():
            o['exchanges'] = []

        exchange_qs = (ExchangeDataset.select(ExchangeDataset.data)
                       .where(ExchangeDataset.output_database == self.name)
                       .dicts())

        for exc in exchange_qs:
            try:
                activities[exc['data']['output']]['exchanges'].append(exc['data'])
            except KeyError:
                # This exchange not in the reduced set of activities returned
                # by _get_queryset
                pass
        return activities

    def new_activity(self, code, **kwargs):
        """Return a new, unsaved ``Activity`` belonging to this database.

        ``code`` is converted to a string; the location defaults to
        ``config.global_location`` and may be overridden via ``kwargs``.
        """
        obj = Activity()
        obj['database'] = self.name
        obj['code'] = str(code)
        obj['location'] = config.global_location
        obj.update(kwargs)
        return obj

    @writable_project
    def make_searchable(self, reset=False):
        """Mark the database searchable and (re)build its search index.

        Does nothing except print a message if the database is already
        searchable and ``reset`` is false.

        Raises:
            UnknownObject: The database is not registered.
        """
        if self.name not in databases:
            raise UnknownObject("This database is not yet registered")
        if self._searchable and not reset:
            print("This database is already searchable")
            return
        databases[self.name]['searchable'] = True
        databases.flush()
        IndexManager(self.filename).delete_database()
        IndexManager(self.filename).add_datasets(self)

    @writable_project
    def make_unsearchable(self):
        """Mark the database as not searchable and delete its search index."""
        databases[self.name]['searchable'] = False
        databases.flush()
        IndexManager(self.filename).delete_database()

    @writable_project
    def delete(self, keep_params=False, warn=True):
        """Delete all data from SQLite database and Whoosh index

        Args:
            keep_params: If false, also delete the database's
                parameters and parameterized exchanges.
            warn: If true, emit a ``UserWarning`` recommending
                ``del databases[name]`` instead.
        """
        if warn:
            MESSAGE = """
            Please use `del databases['{}']` instead.
            Otherwise, the metadata and database get out of sync.
            Call `.delete(warn=False)` to skip this message in the future.
            """
            warnings.warn(MESSAGE.format(self.name), UserWarning)

        ActivityDataset.delete().where(ActivityDataset.database == self.name).execute()
        ExchangeDataset.delete().where(ExchangeDataset.output_database == self.name).execute()
        IndexManager(self.filename).delete_database()

        if not keep_params:
            # Imported here rather than at module level, as in the
            # original code.
            from ...parameters import DatabaseParameter, ActivityParameter, ParameterizedExchange
            groups = tuple({
                o[0] for o in ActivityParameter.select(
                    ActivityParameter.group).where(
                    ActivityParameter.database == self.name).tuples()
            })
            ParameterizedExchange.delete().where(
                ParameterizedExchange.group << groups).execute()
            ActivityParameter.delete().where(ActivityParameter.database == self.name).execute()
            DatabaseParameter.delete().where(DatabaseParameter.database == self.name).execute()

    def process(self):
        """
        Process inventory documents to NumPy structured arrays.

        Use a raw SQLite3 cursor instead of Peewee for a ~2 times speed advantage.

        Two arrays are saved: the geomapping array (one row per
        ``process`` activity) and the processed exchange array (one row
        per stored exchange plus one implicit production exchange for
        each process that has none). The ``depends`` and ``processed``
        metadata entries are updated as well.

        Raises:
            UntypedExchange: An exchange has no ``type``.
            InvalidExchange: An exchange has no ``amount`` or ``input``.
            ValueError: An exchange amount is NaN or infinite.
            UnknownObject: A key lookup fails while building an exchange row.
        """
        # Get number of exchanges and processes to set
        # initial Numpy array size (still have to include)
        # implicit production exchanges
        num_exchanges = ExchangeDataset.select().where(
            ExchangeDataset.output_database == self.name).count()
        num_processes = ActivityDataset.select().where(
            ActivityDataset.database == self.name,
            ActivityDataset.type == "process"
        ).count()

        # Create geomapping array, from dataset keys to locations
        arr = np.zeros((num_processes, ), dtype=self.dtype_fields_geomapping + self.base_uncertainty_fields)

        def retupleize(value):
            """Turn a stringified tuple location back into a tuple."""
            if not value:
                return value
            elif "(" not in value:
                return value
            try:
                # Is this a dirty, dirty hack, or inspiration?
                # Location is retrieved as a string from the database
                # The alternative is to retrieve and process the
                # entire activity dataset...
                # NOTE(review): ``eval`` executes whatever is stored in
                # the location column; acceptable only while the SQLite
                # file is trusted. Consider ``ast.literal_eval``.
                return eval(value)
            except (NameError, SyntaxError):
                # Not everything containing a parenthesis is a tuple,
                # e.g. a plain name such as "Rest of (World)".
                return value

        for index, row in enumerate(ActivityDataset.select(
                ActivityDataset.location, ActivityDataset.code
                ).where(
                    ActivityDataset.database == self.name,
                    ActivityDataset.type == "process"
                ).order_by(ActivityDataset.code).dicts()):
            arr[index] = (
                mapping[(self.name, row['code'])],
                geomapping[retupleize(row['location']) or config.global_location],
                MAX_INT_32, MAX_INT_32,
                0, 1, np.nan, np.nan, np.nan, np.nan, np.nan, False
            )

        arr.sort(order=self.dtype_field_order(
            self.dtype_fields_geomapping + self.base_uncertainty_fields
        ))
        np.save(self.filepath_geomapping(), arr, allow_pickle=False)

        # Figure out when the production exchanges are implicit
        missing_production_keys = [
            (self.name, x[0])
            # Get all codes
            for x in ActivityDataset.select(ActivityDataset.code).where(
                # Get correct database name
                ActivityDataset.database == self.name,
                # Only consider `process` type activities
                ActivityDataset.type << ("process", None),
                # But exclude activities that already have production exchanges
                ~(ActivityDataset.code << ExchangeDataset.select(
                    # Get codes to exclude
                    ExchangeDataset.output_code).where(
                        ExchangeDataset.output_database == self.name,
                        ExchangeDataset.type == 'production'
                    )
                )
            ).tuples()
        ]

        arr = np.zeros((num_exchanges + len(missing_production_keys), ), dtype=self.dtype)

        # Using raw sqlite3 to retrieve data for ~2x speed boost
        connection = sqlite3.connect(sqlite3_lci_db._filepath)
        SQL = "SELECT data, input_database, input_code, output_database, output_code FROM exchangedataset WHERE output_database = ?"

        dependents = set()
        found_exchanges = False

        try:
            cursor = connection.cursor()
            for index, row in enumerate(cursor.execute(SQL, (self.name,))):
                data, input_database, input_code, output_database, output_code = row
                data = pickle.loads(bytes(data))

                if "type" not in data:
                    raise UntypedExchange
                if "amount" not in data or "input" not in data:
                    raise InvalidExchange
                if np.isnan(data['amount']) or np.isinf(data['amount']):
                    raise ValueError("Invalid amount in exchange {}".format(data))

                found_exchanges = True

                dependents.add(input_database)

                try:
                    arr[index] = (
                        mapping[(input_database, input_code)],
                        mapping[(output_database, output_code)],
                        MAX_INT_32,
                        MAX_INT_32,
                        TYPE_DICTIONARY[data["type"]],
                        data.get("uncertainty type", 0),
                        data["amount"],
                        data["amount"]
                        if data.get("uncertainty type", 0) in (0, 1)
                        else data.get("loc", np.nan),
                        data.get("scale", np.nan),
                        data.get("shape", np.nan),
                        data.get("minimum", np.nan),
                        data.get("maximum", np.nan),
                        data["amount"] < 0
                    )
                except KeyError:
                    raise UnknownObject(("Exchange between {} and {} is invalid "
                        "- one of these objects is unknown (i.e. doesn't exist "
                        "as a process dataset)"
                        ).format(
                            (input_database, input_code),
                            (output_database, output_code)
                        )
                    )
        finally:
            # The raw connection is ours alone; always release it.
            connection.close()

        # If exchanges were found, start inserting rows at len(exchanges) + 1
        index = index + 1 if found_exchanges else 0

        for index, key in zip(itertools.count(index), missing_production_keys):
            arr[index] = (
                mapping[key], mapping[key],
                MAX_INT_32, MAX_INT_32, TYPE_DICTIONARY["production"],
                0, 1, 1, np.nan, np.nan, np.nan, np.nan, False
            )

        databases[self.name]['depends'] = sorted(dependents.difference({self.name}))
        databases[self.name]['processed'] = datetime.datetime.now().isoformat()
        databases.flush()

        arr.sort(order=self.dtype_field_order())
        np.save(self.filepath_processed(), arr, allow_pickle=False)

    def search(self, string, **kwargs):
        """Search this database for ``string``.

        The searcher include the following fields:

        * name
        * comment
        * categories
        * location
        * reference product

        ``string`` can include wild cards, e.g. ``"trans*"``.

        By default, the ``name`` field is given the most weight. The full weighting set is called the ``boost`` dictionary, and the default weights are::

            {
                "name": 5,
                "comment": 1,
                "product": 3,
                "categories": 2,
                "location": 3
            }

        Optional keyword arguments:

        * ``limit``: Number of results to return.
        * ``boosts``: Dictionary of field names and numeric boosts - see default boost values above. New values must be in the same format, but with different weights.
        * ``filter``: Dictionary of criteria that search results must meet, e.g. ``{'categories': 'air'}``. Keys must be one of the above fields.
        * ``mask``: Dictionary of criteria that exclude search results. Same format as ``filter``.
        * ``facet``: Field to facet results. Must be one of ``name``, ``product``, ``categories``, ``location``, or ``database``.
        * ``proxy``: Return ``Activity`` proxies instead of raw Whoosh documents. Default is ``True``.

        Returns a list of ``Activity`` datasets."""
        with Searcher(self.filename) as s:
            results = s.search(string, **kwargs)
        return results

    def graph_technosphere(self, filename=None, **kwargs):
        """Graph the technosphere matrix of an LCA of a random activity.

        ``filename`` and ``kwargs`` are passed on to
        ``SparseMatrixGrapher.ordered_graph``, whose result is returned.
        """
        # Imported lazily, as in the original code.
        from bw2analyzer.matrix_grapher import SparseMatrixGrapher
        from bw2calc import LCA

        lca = LCA({self.random(): 1})
        lca.lci()

        smg = SparseMatrixGrapher(lca.technosphere_matrix)
        return smg.ordered_graph(filename, **kwargs)