from __future__ import annotations
import base64
import contextlib
import cProfile
import datetime
import io
import json
import logging
import os
import pathlib
import posixpath
import pstats
import random
import sys
import time
import urllib.parse
from functools import update_wrapper, wraps
from typing import Callable
from dateutil import parser
from libb import expandabspath, grouper, splitcap
with contextlib.suppress(ImportError):
import flask
logger = logging.getLogger(__name__)
# Names exported by a star-import of this module.
# NOTE(review): 'appmenu' is listed below but no definition of it is visible
# in this file; if it is not defined elsewhere, a star-import will raise
# AttributeError -- confirm and either restore the function or drop the entry.
__all__ = [
'authd',
'make_url',
'appmenu',
'render_field',
'scale',
'safe_join',
'local_or_static_join',
'inject_file',
'inject_image',
'htmlquote',
'websafe',
'rsleep',
'rand_retry',
'logerror',
'COOKIE_DEFAULTS',
'JSONEncoderISODate',
'JSONDecoderISODate',
'Jinja2Render',
'ProfileMiddleware',
'get_request_dict',
'is_safe_redirect_url',
'external_url_for',
]
class _CookieDefaults(dict):
    """Mapping of default keyword arguments for the ``cookielib.Cookie`` constructor."""


# Written in keyword form; keys, values and insertion order match the
# constructor defaults one-to-one.
COOKIE_DEFAULTS = _CookieDefaults(
    version=0,
    domain='',
    domain_specified=False,
    domain_initial_dot=False,
    port=None,
    port_specified=False,
    path='/',
    path_specified=True,
    secure=False,
    expires=None,
    discard=True,
    comment=None,
    comment_url=None,
    rest={'HttpOnly': None},
    rfc2109=False,
)
#
# webscraping utils
#
[docs]
def rsleep(always=0, rand_extra=8):
    """Sleep for ``always`` seconds plus a random extra delay.

    The random part is drawn in whole milliseconds from
    ``[0, max(rand_extra, 1))`` seconds, so a window of at least one second
    is always used. The total is clamped at zero before sleeping.

    :param float always: Minimum seconds to sleep.
    :param float rand_extra: Maximum additional random seconds.
    """
    # randrange() only accepts integers; truncate so a float ``rand_extra``
    # (the documented type) does not raise.
    upper_ms = int(max(rand_extra, 1) * 1000)
    seconds = max(always + random.randrange(0, upper_ms) * 0.001, 0)
    logger.debug('Sleeping %0.2f seconds ...', seconds)
    time.sleep(seconds)
[docs]
def rand_retry(x_times=10, exception=Exception):
    """Build a decorator that re-invokes a function after random pauses.

    Intended for web requests where evenly spaced retries could trip
    automated throttling.

    The decorated function is called once, then retried up to ``x_times``
    more times whenever it raises ``exception``. When every attempt has
    failed the wrapper logs a warning and returns ``None``; the last
    exception is not re-raised. Other exception types propagate immediately.

    :param int x_times: Maximum number of retries.
    :param exception: Exception type(s) to catch and retry on.
    :returns: Decorator function.
    """
    def decorator(fn):
        @wraps(fn)
        def retrying(*args, **kwargs):
            failures = 0
            while failures <= x_times:
                try:
                    return fn(*args, **kwargs)
                except exception as err:
                    logger.debug(err)
                failures += 1
                if failures <= x_times:
                    logger.warning(f'Retry number {failures}')
                    # the pause floor grows with each failed attempt
                    rsleep(failures)
                    continue
                logger.warning(f'Retried function {x_times} times without success.')
                return None
            return None
        return retrying
    return decorator
#
# commonly reused decorators
#
[docs]
def authd(checker_fn, fallback_fn):
    """Build a decorator that gates a function behind an auth check.

    ``checker_fn`` is called with no arguments on every invocation. When it
    returns a falsy value the decorated function is skipped and the result
    of ``fallback_fn()`` is returned instead.

    :param checker_fn: Callable that returns True if authorized.
    :param fallback_fn: Callable to invoke if not authorized.
    :returns: Decorator function.
    """
    def decorate(view):
        @wraps(view)
        def guarded(*args, **kwargs):
            if checker_fn():
                return view(*args, **kwargs)
            return fallback_fn()
        return guarded
    return decorate
#
# URL and path utility methods
#
[docs]
def make_url(path, **params):
    """Generate URL with query parameters.

    Inspired by ``werkzeug.urls.Href``. Assumes traditional multiple params
    (does not overwrite): a key present both in ``params`` and in the query
    string of ``path`` ends up with all values, keyword values first.
    Use ``__replace__`` (a dict) to overwrite params.
    Use ``__ignore__`` (a predicate on the key) to filter out certain params.

    Callable values are invoked and their result is used. Any other keyword
    starting with ``__`` is dropped. Sequences passed by the caller are never
    modified.

    :param str path: Base URL path, optionally with an existing query string.
    :param params: Query parameters.
    :returns: Complete URL with query string.
    :rtype: str

    Example::

        >>> ignore_fn = lambda x: x.startswith('_')
        >>> kw = dict(fuz=1, biz="boo")
        >>> make_url('/foo/', _format='excel', __ignore__=ignore_fn, **kw)
        '/foo/?fuz=1&biz=boo'
        >>> make_url('/foo/?bar=1', _format='excel', **kw)
        '/foo/?_format=excel&fuz=1&biz=boo&bar=1'
        >>> make_url('/foo/', bar=1, baz=2)
        '/foo/?bar=1&baz=2'
        >>> make_url('/foo/', **{'bar':1, 'fuz':(1,2,), 'biz':"boo"})
        '/foo/?bar=1&fuz=1&fuz=2&biz=boo'
        >>> make_url('/foo/?a=1&a=2')
        '/foo/?a=1&a=2'
        >>> make_url('/foo/?fuz=3', fuz=(1, 2))
        '/foo/?fuz=1&fuz=2&fuz=3'
        >>> kwargs = dict(fuz=1, biz="boo", __ignore__=ignore_fn)
        >>> xx = make_url('www.foobar.com/foo/', **kwargs)
        >>> 'www' in xx and 'foobar' in xx and '/foo/' in xx and 'fuz=1' in xx and 'biz=boo' in xx
        True
        >>> xx = make_url('/foo/', _format='excel', **kwargs)
        >>> '_format=excel' in xx
        False
        >>> 'fuz=1' in xx
        True
        >>> 'biz=boo' in xx
        True
        >>> yy = make_url('/foo/?bar=1', _format='excel', **kwargs)
        >>> 'bar=1' in yy
        True
        >>> '_format=excel' in yy
        False
        >>> zz = make_url('/foo/', **{'bar':1, 'fuz':(1,2,), 'biz':"boo"})
        >>> 'fuz=1' in zz
        True
        >>> 'fuz=2' in zz
        True
        >>> qq = make_url('/foo/?a=1&a=2')
        >>> 'a=1' in qq
        True
        >>> 'a=2' in qq
        True
    """
    replace = params.pop('__replace__', {})
    ignore = params.pop('__ignore__', None)
    # Remaining dunder-prefixed keys are reserved for options and never emitted.
    params = {k: v() if callable(v) else v for k, v in params.items() if not k.startswith('__')}
    parsed = list(urllib.parse.urlparse(path))
    # Index 4 of the urlparse result is the query string.
    for k, v in urllib.parse.parse_qsl(parsed[4]):
        if k not in params:
            params[k] = v
            continue
        current = params[k]
        if isinstance(current, (list, tuple)) or hasattr(current, 'append'):
            # Build a fresh list: this leaves the caller's sequence untouched
            # and keeps tuple members flat so urlencode emits one pair each.
            params[k] = [*current, v]
        else:
            params[k] = [current, v]
    params.update(replace)
    if ignore:
        params = {k: v for k, v in params.items() if not ignore(k)}
    parsed[4] = urllib.parse.urlencode(params, doseq=True)
    return urllib.parse.urlunparse(parsed)
# Path separators of the host OS other than '/', which is handled by posixpath.
_os_alt_seps: list[str] = [
    sep for sep in [os.sep, os.path.altsep] if sep is not None and sep != '/'
]


def safe_join(directory: str, *pathnames: str) -> str | None:
    """Safely join untrusted path components to a base directory.

    Prevents escaping the base directory via path traversal. Each component
    is normalised with ``posixpath.normpath`` and rejected when it contains a
    non-``/`` OS separator, is absolute, starts with ``/``, or resolves to
    the parent directory.

    :param str directory: The trusted base directory.
    :param pathnames: The untrusted path components relative to base.
    :returns: A safe path, or ``None`` if path would escape base.
    :rtype: str | None

    .. note::
        Via github.com/mitsuhiko/werkzeug security.py
    """
    if not directory:
        # Ensure we end up with ./path if directory="" is given,
        # otherwise the first untrusted part could become trusted.
        directory = '.'
    parts = [directory]
    for filename in pathnames:
        if filename != '':
            # normpath collapses 'a/../..' style segments so the checks
            # below only need to look at the start of the string
            filename = posixpath.normpath(filename)
        if (any(sep in filename for sep in _os_alt_seps)
                or pathlib.Path(filename).is_absolute()
                # A rooted path without a drive is not "absolute" for
                # pathlib on Windows, yet posixpath.join would still let it
                # replace the base directory.
                or filename.startswith('/')
                or filename == '..'
                or filename.startswith('../')):
            return None
        parts.append(filename)
    return posixpath.join(*parts)
[docs]
def local_or_static_join(static, somepath):
    """Locate a template, preferring the expanded local path over the static folder.

    ``somepath`` is first expanded with ``expandabspath``; if that location
    exists it is returned. Otherwise ``somepath`` is joined onto ``static``
    with ``safe_join`` and that path is returned when it exists.

    :param str static: Static folder path.
    :param str somepath: Relative path to template.
    :returns: Full path to existing template.
    :rtype: Path
    :raises OSError: If template not found in either location.
    """
    local_candidate = expandabspath(somepath)
    if not local_candidate.exists():
        # safe_join yields None when somepath would escape the static folder
        joined = safe_join(static, somepath)
        static_candidate = pathlib.Path(joined) if joined else None
        if static_candidate is None or not static_candidate.exists():
            raise OSError('That template does not exist on your path or in the local package.')
        return static_candidate
    return local_candidate
[docs]
def inject_file(x):
    """Return the text of a file so it can be inlined into an HTML email template.

    The file is opened in text mode with ``encoding=None``, i.e. the
    platform's default text encoding.

    :param str x: Path to file (CSS, JS, etc.).
    :returns: File contents.
    :rtype: str
    """
    source = pathlib.Path(x)
    with source.open(encoding=None) as handle:
        contents = handle.read()
    return contents
[docs]
def inject_image(x):
    """Generate base64 data URI for image embedding in HTML.

    The MIME subtype is taken verbatim from the file extension (``.png``
    gives ``image/png``); no mapping or case folding is applied.

    :param str x: Path to image file.
    :returns: Data URI string for use in img src attribute.
    :rtype: str
    """
    _, ext = os.path.splitext(x)
    with pathlib.Path(x).open('rb') as f:
        # b64encode returns bytes; decode so the f-string embeds the payload
        # itself rather than its b'...' repr.
        code = base64.b64encode(f.read()).decode('ascii')
    return f"data:image/{ext.strip('.')};base64,{code}"
[docs]
def scale(color, pct):
    """Scale a hex color by a percentage.

    Each channel is multiplied by ``pct``, rounded half-up and clamped to
    ``0..255``. The result is always the six-digit upper-case form.

    :param str color: Hex color string (e.g., '#FFF' or '#FFFFFF').
    :param float pct: Scale factor (1.0 = no change).
    :returns: Scaled hex color string.
    :rtype: str

    Example::

        >>> scale('#FFFFFF', 0.5)
        '#808080'
        >>> scale('#102030', 0.5)
        '#081018'
    """
    digits = color[1:]
    if len(color) == 4:
        # shorthand '#RGB' expands to '#RRGGBB' by doubling each digit
        digits = ''.join(ch * 2 for ch in digits)
    channels = (digits[0:2], digits[2:4], digits[4:])
    scaled = [min(max(0, int(int(ch, 16) * pct + 0.5)), 255) for ch in channels]
    # two digits per channel: values below 0x10 must keep their leading zero
    return '#' + ''.join(f'{value:02X}' for value in scaled)
[docs]
def render_field(field) -> str:
    """Render form field with error styling.

    When the field carries an error message, ``str(field)`` is wrapped in a
    ``<span class="flderr">`` whose ``title`` holds the message. The message
    comes from ``field.note`` if that attribute exists, otherwise from the
    comma-joined ``field.errors``. The message is HTML-escaped before it is
    placed in the attribute.

    :param field: Form field object.
    :returns: HTML string for rendered field.
    """
    import html  # local import: ``html`` is not imported at module level

    def get_error(field):
        if hasattr(field, 'note'):
            return field.note
        if hasattr(field, 'errors'):
            # str() each item so lazily-translated or non-string errors join
            return ', '.join(str(item) for item in field.errors)
        return None

    parts = []
    error = get_error(field)
    if error:
        # Error text may echo user input; escape it (quotes included) so it
        # cannot break out of the title attribute.
        title = html.escape(f'{error}', quote=True)
        parts.append(f'<span class="flderr" title="{title}">')
    parts.append(str(field))
    if error:
        parts.append('</span>')
    return '\n'.join(parts)
[docs]
class JSONEncoderISODate(json.JSONEncoder):
    """JSON encoder that writes date and datetime values as ISO-8601 strings.

    Example::

        >>> JSONEncoderISODate().encode({'dt': datetime.date(2014, 10, 2)})
        '{"dt": "2014-10-02"}'
    """

    def default(self, obj):
        """Return ``obj.isoformat()`` for dates; defer to the base class otherwise."""
        # datetime.datetime subclasses datetime.date, so one check covers both
        if not isinstance(obj, datetime.date):
            return super().default(obj)
        return obj.isoformat()
[docs]
class JSONDecoderISODate(json.JSONDecoder):
    """JSON decoder that parses date strings into datetime objects.

    Every string value of every decoded JSON object is passed to
    ``dateutil.parser.parse``; values it accepts are replaced by the parsed
    result, all others are left as strings.

    Example::

        >>> JSONDecoderISODate().decode('{"dt": "2014-10-02"}')
        {'dt': datetime.datetime(2014, 10, 2, 0, 0)}
    """

    def __init__(self, **kw):
        """Forward ``kw`` to :class:`json.JSONDecoder` and install the date hook."""
        super().__init__(object_hook=self._parse_date_hook, **kw)

    def _parse_date_hook(self, obj):
        """Replace, in place, each string value of ``obj`` that parses as a date."""
        if isinstance(obj, dict):
            for key, value in obj.items():
                if isinstance(value, str):
                    # Strings that fail to parse are kept as-is. OverflowError
                    # is included because dateutil raises it for numbers
                    # outside the representable range; one odd value should
                    # not abort the whole decode.
                    with contextlib.suppress(ValueError, TypeError, OverflowError):
                        obj[key] = parser.parse(value)
        return obj
[docs]
class ProfileMiddleware:
    """WSGI middleware for profiling requests.

    Each call runs the wrapped application under ``cProfile``, logs the
    elapsed time at INFO level and the top ``count`` profile rows at DEBUG
    level, then returns the application's result unchanged.

    :param func: WSGI application callable.
    :param log: Logger instance for output. Defaults to this module's logger.
    :param str sort: Profile sort key (default 'time').
    :param int count: Number of top functions to show (default 20).

    .. warning::
        Should always be last middleware loaded:
        1. You want to profile everything else
        2. For speed, we return the result NOT the wrapped func
    """

    def __init__(self, func, log=None, sort='time', count=20):
        self.func = func
        # Fall back to a real logger: with the default ``None`` every request
        # would otherwise fail on ``self.log.info``.
        self.log = log if log is not None else logging.getLogger(__name__)
        self.sort = sort
        self.count = count

    def __call__(self, environ, start_response):
        """Profile one WSGI call and return the application's response."""
        stime = time.time()
        pr = cProfile.Profile()
        result = pr.runcall(self.func, environ, start_response)
        etime = time.time() - stime
        self.log.info(f'Run finished in {etime} seconds')
        with io.StringIO() as s:
            ps = pstats.Stats(pr, stream=s).sort_stats(self.sort)
            ps.print_stats(self.count)
            self.log.debug(s.getvalue())
        return result
[docs]
def logerror(olderror, logger):
    """Wrap internalerror function to log tracebacks.

    The returned handler calls ``olderror()`` and returns its result. If an
    exception is being handled at the time of the call, it is also logged at
    ERROR level together with its traceback.

    :param olderror: Original error handler function.
    :param logger: Logger instance for error output.
    :returns: Wrapped error handler.
    """
    def logerror_fn():
        # Capture the active exception before delegating to the original handler.
        exc_info = sys.exc_info()
        theerr = olderror()
        if exc_info[1] is not None:
            # Pass the full triple so the traceback is recorded, not just
            # the exception message.
            logger.error(exc_info[1], exc_info=exc_info)
        return theerr
    return logerror_fn
[docs]
def htmlquote(text):
    r"""Encode text for safe use in HTML.

    Replaces ``&``, ``<``, ``>``, ``'`` and ``"`` with their HTML entities.

    :param str text: Text to encode.
    :returns: HTML-encoded string.
    :rtype: str

    Example::

        >>> htmlquote("<'&\">")
        '&lt;&#39;&amp;&quot;&gt;'
    """
    # '&' must be replaced first, otherwise the ampersands introduced by the
    # later replacements would themselves be escaped again.
    text = text.replace('&', '&amp;')
    text = text.replace('<', '&lt;')
    text = text.replace('>', '&gt;')
    text = text.replace("'", '&#39;')
    text = text.replace('"', '&quot;')
    return text
[docs]
def websafe(val):
    r"""Convert value to safe Unicode HTML string.

    ``None`` becomes the empty string, ``bytes`` are decoded as UTF-8 and
    any other non-string value is passed through ``str()`` before the
    result is HTML-escaped with ``htmlquote``.

    :param val: Value to convert (string, bytes, or None).
    :returns: HTML-safe string.
    :rtype: str

    Example::

        >>> websafe("<'&\">")
        '&lt;&#39;&amp;&quot;&gt;'
        >>> websafe(None)
        ''
        >>> websafe(u'\u203d') == u'\u203d'
        True
    """
    if val is None:
        return ''
    if isinstance(val, bytes):
        val = val.decode('utf-8')
    elif not isinstance(val, str):
        val = str(val)
    return htmlquote(val)
#
# Jinja2 rendering infrastructure (Flask-compatible)
#
[docs]
class Jinja2Render:
    """Jinja2 render class.

    Wraps a ``jinja2.Environment`` backed by a ``FileSystemLoader`` and makes
    the instance itself callable for rendering.

    Usage::

        render = Jinja2Render('templates/')
        render.add_globals({'format': fmt, 'today': datetime.date.today})
        html = render('generic.html', title='Page', content=[html1, html2])
    """

    def __init__(
        self,
        template_dir: str,
        globals: dict | None = None,
        autoescape: bool = True,
    ) -> None:
        """Initialize Jinja2 environment with template directory.

        :param template_dir: Path to template directory.
        :param globals: Dict of global variables/functions for templates.
        :param autoescape: Enable HTML autoescaping (default True).
        """
        # Imported here so the module can be loaded without jinja2 installed.
        from jinja2 import Environment, FileSystemLoader

        self.env = Environment(
            loader=FileSystemLoader(template_dir),
            autoescape=autoescape,
            trim_blocks=True,
            lstrip_blocks=True,
        )
        if globals:
            self.env.globals.update(globals)

    def __call__(self, template_name: str, **context) -> str:
        """Render template with context - same signature as Flask's render_template.

        :param template_name: Name of template file (e.g., 'generic.html').
        :param context: Keyword arguments passed to template.
        :returns: Rendered HTML string.
        """
        template = self.env.get_template(template_name)
        return template.render(**context)

    def add_globals(self, globals_dict: dict) -> None:
        """Add globals to Jinja2 environment.

        :param globals_dict: Dict of globals to add.
        """
        self.env.globals.update(globals_dict)

    def add_filter(self, name: str, func: Callable) -> None:
        """Add custom filter to Jinja2 environment.

        :param name: Filter name to use in templates.
        :param func: Filter function.
        """
        self.env.filters[name] = func
[docs]
def get_request_dict(**defaults) -> dict:
    """Collect the current request's values and fill gaps from ``defaults``.

    Inside a Flask request context the dict starts from
    ``flask.request.values``; otherwise (or when flask is unavailable) it
    starts empty. A default is applied when its key is missing or maps to
    the empty string. Callable defaults are invoked only when applied.

    Example::

        get_request_dict(fund='All', date=lambda: Date.today())

    :param defaults: Default values for parameters. Callables are invoked.
    :returns: Dict of request parameters with defaults applied.
    """
    req = {}
    # NameError covers the case where the optional flask import failed
    with contextlib.suppress(NameError, AttributeError):
        if flask.has_request_context():
            req = flask.request.values.to_dict()
    for name, fallback in defaults.items():
        if req.get(name, '') != '':
            continue
        req[name] = fallback() if callable(fallback) else fallback
    return req
[docs]
def is_safe_redirect_url(url: str) -> bool:
    r"""Check if redirect URL is safe (relative path only, no protocol injection).

    A URL is accepted only when it starts with a single ``/`` and contains
    neither a backslash nor an ASCII control character.

    :param url: URL to validate.
    :returns: True if URL is safe for redirect.

    Example::

        >>> is_safe_redirect_url('/login/')
        True
        >>> is_safe_redirect_url('//evil.com')
        False
        >>> is_safe_redirect_url('https://evil.com')
        False
        >>> is_safe_redirect_url('')
        False
        >>> is_safe_redirect_url('/\\evil.com')
        False
    """
    if not url:
        return False
    # Browsers treat '\' like '/' and drop tab/CR/LF while parsing, so
    # '/\evil.com' or '/<TAB>/evil.com' would be followed as '//evil.com'.
    if '\\' in url or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in url):
        return False
    return url.startswith('/') and not url.startswith('//')
[docs]
def external_url_for(base_url: str, endpoint: str, **values) -> str:
    """Build an absolute URL for an endpoint, for use outside the site (emails, etc).

    The endpoint is resolved with ``flask.url_for`` and the result is joined
    onto ``base_url`` with ``urllib.parse.urljoin``.

    :param base_url: Base URL including scheme and domain (e.g. 'https://app.example.com').
    :param endpoint: Flask endpoint name.
    :param values: URL parameters passed to url_for.
    :returns: Complete URL.
    """
    endpoint_url = flask.url_for(endpoint, **values)
    return urllib.parse.urljoin(base_url, endpoint_url)
if __name__ == '__main__':
    # Run this module's doctests when executed as a script.
    import doctest

    # Named flags for the literal 4 | 8 | 32.
    doctest.testmod(
        optionflags=(
            doctest.NORMALIZE_WHITESPACE
            | doctest.ELLIPSIS
            | doctest.IGNORE_EXCEPTION_DETAIL
        ),
    )