Source code for bw2parameters.interpreter

from collections.abc import Iterable
from functools import wraps
from numbers import Number

import numpy as np
from asteval import Interpreter as ASTInterpreter
from asteval import NameFinder

from .errors import MissingName
from .pint import PintWrapper


class Interpreter(ASTInterpreter):
    """asteval interpreter extended with symbol inspection helpers.

    Adds utilities to list the symbols used by an expression, to add and
    remove symbols from the symtable, and to evaluate an expression with
    symbols that are only present for the duration of the call. The
    unit-related classmethods are no-op defaults that ``PintInterpreter``
    overrides.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Names present right after construction; `user_defined_symbols`
        # subtracts these to report only what was added afterwards.
        self.BUILTIN_SYMBOLS = set(self.symtable)

    @classmethod
    def is_numeric(cls, value):
        """Return True if `value` is a `Number` or a numpy array."""
        return isinstance(value, (Number, np.ndarray))

    def _raise_missing_name(func):  # noqa
        """Decorator: re-raise `NameError`/`SyntaxError` as `MissingName(expr)`."""

        @wraps(func)
        def wrapper(self, expr, *args, **kwargs):
            try:
                return func(self, expr, *args, **kwargs)
            except (NameError, SyntaxError) as err:
                # Keep the original error reachable as __cause__.
                raise MissingName(expr) from err

        return wrapper

    @_raise_missing_name
    def get_symbols(self, text):
        """
        Parses an expression and returns the set of all symbol names in it.

        Returns an empty set if `text` is None. Raises `MissingName` if
        parsing raises `NameError` or `SyntaxError`.
        """
        if text is None:
            return set()
        nf = NameFinder()
        nf.generic_visit(self.parse(text))
        return set(nf.names)

    def get_unknown_symbols(
        self,
        text,
        known_symbols=None,
        ignore_symtable=False,
        no_pint_units=None,
    ):
        """
        Parses an expression and returns all symbols which are neither in the
        symtable nor passed via known_symbols.

        :param text: expression to parse; None yields an empty set
        :param known_symbols: iterable of names to treat as known
        :param ignore_symtable: if True, symtable names do not count as known
        :param no_pint_units: unused here; accepted so the signature is
            compatible with `PintInterpreter.get_unknown_symbols`
        :raises ValueError: if `known_symbols` is neither None nor iterable
        """
        if text is None:
            return set()
        if known_symbols is None:
            known_symbols = set()
        elif isinstance(known_symbols, Iterable):
            known_symbols = set(known_symbols)
        else:
            raise ValueError(
                f"Parameter known_symbols must be iterable. Is {type(known_symbols)}."
            )
        if not ignore_symtable:
            known_symbols = known_symbols.union(set(self.symtable.keys()))
        all_symbols = set(self.get_symbols(text))
        return all_symbols.difference(known_symbols)

    def add_symbols(self, symbols):
        """Adds symbols (a mapping name -> value) to the symtable. None is a no-op."""
        if symbols is None:
            return
        self.symtable.update(symbols)

    def remove_symbols(self, symbols):
        """
        Removes symbols from the symtable. None is a no-op.

        `symbols` may be a dict (its keys are used) or any iterable of names.
        Raises `KeyError` if a name is not in the symtable.
        """
        if symbols is None:
            return
        if isinstance(symbols, dict):
            symbols = set(symbols)
        for symbol in symbols:
            self.symtable.pop(symbol)

    def user_defined_symbols(self):
        """Return the symtable names that were not present after construction."""
        return set(self.symtable).difference(self.BUILTIN_SYMBOLS)

    @_raise_missing_name
    def eval(self, expr, *args, known_symbols=None, raise_errors=True, **kwargs):
        """
        Evaluates `expr` with `known_symbols` temporarily added to the symtable.

        The temporary symbols are removed again even if evaluation raises, so
        a failed evaluation does not leave them behind. Note that a name in
        `known_symbols` that already exists in the symtable is overwritten
        and then removed, not restored.

        Raises `MissingName` if evaluation raises `NameError` or `SyntaxError`.
        """
        self.add_symbols(known_symbols)
        try:
            # `expr` is passed positionally: combining `expr=expr` with
            # forwarded positional args would bind `expr` twice.
            return super().eval(expr, *args, raise_errors=raise_errors, **kwargs)
        finally:
            self.remove_symbols(known_symbols)

    @classmethod
    def parameter_list_to_dict(cls, param_list):
        """Map each parameter dict's "name" to its "amount"."""
        return {d["name"]: d["amount"] for d in param_list}

    @classmethod
    def is_quantity(cls, value):
        """Always False: this interpreter does not handle quantities."""
        return False

    @classmethod
    def is_quantity_from_same_registry(cls, value):
        """Always False: this interpreter does not handle quantities."""
        return False

    @classmethod
    def get_unit_dimensionality(
        cls, unit_name=None
    ):  # signature must be same for Interpreter and PintInterpreter # noqa
        """Return an empty dict (no unit support in this interpreter)."""
        return dict()

    @classmethod
    def set_amount_and_unit(cls, obj, quantity, to_unit=None):
        """Store `quantity` as `obj["amount"]`; `to_unit` is ignored here."""
        obj["amount"] = quantity
class PintInterpreter(Interpreter):
    """Interpreter variant that resolves unit symbols through `PintWrapper`."""

    def __init__(self, *args, units=None, **kwargs):
        super().__init__(*args, **kwargs)
        if units is not None:
            # Pre-register the requested units; unparsable names raise.
            self.add_symbols(PintWrapper.to_units(units, raise_errors=True))

    @classmethod
    def is_numeric(cls, value):
        """Return True for numbers, numpy arrays and `PintWrapper.GeneralQuantity`."""
        return super().is_numeric(value) or isinstance(
            value, PintWrapper.GeneralQuantity
        )

    def parse(self, text):
        """Parse `text` after running it through `PintWrapper.string_preprocessor`."""
        return super().parse(PintWrapper.string_preprocessor(text))

    def get_unknown_symbols(
        self,
        text,
        known_symbols=None,
        ignore_symtable=False,
        include_pint_units=False,
        no_pint_units=None,
    ):
        """
        Parses the given expression and returns the set of symbols which are
        neither contained in the symtable, nor in known_symbols, nor can be
        interpreted as pint units.

        :param include_pint_units: if True, symbols that parse as pint units
            are kept in the result
        :param no_pint_units: names that must be reported as unknown even if
            they can be parsed as pint units
        """
        unknown_symbols = super().get_unknown_symbols(
            text=text,
            known_symbols=known_symbols,
            ignore_symtable=ignore_symtable,
        )
        # exclude symbols which can be parsed as pint units and are not in `no_pint_units`
        if not include_pint_units:
            pint_units = PintWrapper.to_units(
                unknown_symbols, raise_errors=False, drop_none=True
            )
            # exclude explicitly defined symbols
            pint_units = set(pint_units).difference(no_pint_units or set())
            unknown_symbols = unknown_symbols.difference(pint_units)
        return unknown_symbols

    def get_pint_symbols(self, text, known_symbols=None, ignore_symtable=True):
        """
        Parses an expression and returns all symbols which can be interpreted
        as pint units, as returned by `PintWrapper.to_units`.

        Returns an empty dict if `text` is None.
        """
        if text is None:
            return dict()
        # get all unknown symbols (base-class lookup, without the pint filter)
        unknown_symbols = super().get_unknown_symbols(
            text=text,
            known_symbols=known_symbols,
            ignore_symtable=ignore_symtable,
        )
        # filter those which can be interpreted as a pint.Unit
        pint_symbols = PintWrapper.to_units(
            unknown_symbols, raise_errors=False, drop_none=True
        )
        return pint_symbols

    @classmethod
    def is_quantity(cls, value):
        """Delegate to `PintWrapper.is_quantity`."""
        return PintWrapper.is_quantity(value)

    @classmethod
    def is_quantity_from_same_registry(cls, value):
        """Delegate to `PintWrapper.is_quantity_from_same_registry`."""
        return PintWrapper.is_quantity_from_same_registry(value)

    @classmethod
    def get_unit_dimensionality(cls, unit_name=None):
        """Delegate to `PintWrapper.get_dimensionality`."""
        return PintWrapper.get_dimensionality(unit_name)

    def add_symbols(self, symbols):
        """
        Adds symbols to symtable while making sure that pint Quantities are
        from the same registry as the wrapper's (otherwise self.eval will fail).

        Note: quantities from another registry are replaced inside the passed
        `symbols` dict itself, so the caller's dict is modified. None is a
        no-op.
        """
        if symbols is None:
            return
        for k, v in symbols.items():
            # if value is a quantity from another unit registry -> convert to current unit registry
            if PintWrapper.is_quantity(
                v
            ) and not PintWrapper.is_quantity_from_same_registry(v):
                symbols[k] = PintWrapper.Quantity(value=v.m, units=v.u)
        super().add_symbols(symbols=symbols)

    def _raise_proper_pint_exception(func):  # noqa
        """Make sure that pint exceptions are correctly raised during evaluation"""

        @wraps(func)
        def wrapper(self, expr, *args, **kwargs):
            try:
                return func(self, expr, *args, **kwargs)  # noqa
            except TypeError:
                try:
                    PintWrapper.ureg.parse_expression(
                        expr, **self.symtable
                    )  # will raise proper exception
                except Exception as error:
                    error.extra_msg = f": {expr}"
                    raise error from None  # omit previous exceptions
                # pint parsed the expression without complaint, so there is
                # no better error to report: surface the original TypeError
                # instead of silently returning None.
                raise

        return wrapper

    @_raise_proper_pint_exception  # noqa
    def eval(self, expr, *args, known_symbols=None, **kwargs):
        """
        Evaluates `expr` after adding every symbol in it that can be
        interpreted as a pint unit to the symtable.

        The unit symbols stay in the symtable after the call; `known_symbols`
        are handled by `Interpreter.eval`.
        """
        pint_symbols = self.get_pint_symbols(
            text=expr, known_symbols=known_symbols, ignore_symtable=False
        )
        self.add_symbols(pint_symbols)
        # `expr` is passed positionally: combining `expr=expr` with forwarded
        # positional args would bind `expr` twice.
        return super().eval(expr, *args, known_symbols=known_symbols, **kwargs)

    @classmethod
    def parameter_list_to_dict(cls, param_list):
        """
        Takes a list of parameter objects and returns a dict where keys are the
        parameter names and values are the interpreted pint.Quantities (or
        float where no unit is defined).

        The unit is read from `d["unit"]`, falling back to `d["data"]["unit"]`.
        """
        return {
            d["name"]: PintWrapper.to_quantity(
                amount=d["amount"], unit=d.get("unit") or d.get("data", {}).get("unit")
            )
            for d in param_list
        }

    @classmethod
    def set_amount_and_unit(cls, obj, quantity=None, to_unit=None):
        """
        Takes an arbitrary object and tries to set its `amount` and `unit`
        fields.

        `amount` is the magnitude of the pint.Quantity after conversion to
        `to_unit`. If no `to_unit` is given, the quantity's own unit is used.
        If the input is not a pint.Quantity then `obj['unit']` is used (falling
        back to `to_unit`). If no quantity is given (None), then
        `obj['amount']` and `obj['unit']` are used. Nothing is written when no
        amount can be determined.
        """
        is_quantity = cls.is_quantity(quantity)
        if is_quantity:
            amount = quantity.m
            unit = str(quantity.u)
        else:
            # Explicit None check: a plain `or` would discard a legitimate
            # amount of 0 and fails on numpy arrays.
            amount = quantity if quantity is not None else obj.get("amount")
            unit = obj.get("unit") or to_unit
        if amount is None:
            return
        if unit is None:
            obj["amount"] = amount
            return
        to_unit = to_unit or unit
        if unit == to_unit:
            obj["amount"] = amount
            obj["unit"] = unit
        else:
            quantity = (
                quantity if is_quantity else PintWrapper.to_quantity(amount, unit)
            )
            obj["amount"] = quantity.to(to_unit).m
            obj["unit"] = to_unit