Skip to content

Commit

Permalink
Add --factory-drawer option to console script
Browse files Browse the repository at this point in the history
  • Loading branch information
SmileyChris committed Feb 3, 2022
1 parent 2287382 commit 172a172
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 18 deletions.
87 changes: 74 additions & 13 deletions qrcode/console_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
When stdout is a tty the QR Code is printed to the terminal and when stdout is
a pipe to a file an image is written. The default image format is PNG.
"""
import sys
import optparse
import os
import sys
from typing import Iterable, Optional, Type

import qrcode
from qrcode.image.base import BaseImage, DrawerAliases

# The next block is added to get the terminal to display properly on MS platforms
if sys.platform.startswith(("win", "cygwin")): # pragma: no cover
Expand Down Expand Up @@ -43,9 +46,11 @@ def main(args=None):
"--factory",
help="Full python path to the image factory class to "
"create the image with. You can use the following shortcuts to the "
"built-in image factory classes: {}.".format(
", ".join(sorted(default_factories.keys()))
),
f"built-in image factory classes: {commas(default_factories)}.",
)
parser.add_option(
"--factory-drawer",
help=f"Use an alternate drawer. {get_drawer_help()}.",
)
parser.add_option(
"--optimize",
Expand Down Expand Up @@ -73,18 +78,21 @@ def main(args=None):

opts, args = parser.parse_args(args)

qr = qrcode.QRCode(error_correction=error_correction[opts.error_correction])

if opts.factory:
module = default_factories.get(opts.factory, opts.factory)
if "." not in module:
parser.error("The image factory is not a full python path")
module, name = module.rsplit(".", 1)
imp = __import__(module, {}, [], [name])
image_factory = getattr(imp, name)
try:
image_factory = get_factory(module)
except ValueError as e:
parser.error(str(e))
image_factory = None
else:
image_factory = None

qr = qrcode.QRCode(
error_correction=error_correction[opts.error_correction],
image_factory=image_factory,
)

if args:
data = args[0]
data = data.encode(errors="surrogateescape")
Expand All @@ -99,15 +107,32 @@ def main(args=None):
qr.add_data(data, optimize=opts.optimize)

if opts.output:
img = qr.make_image(image_factory=image_factory)
img = qr.make_image()
with open(opts.output, "wb") as out:
img.save(out)
else:
if image_factory is None and (os.isatty(sys.stdout.fileno()) or opts.ascii):
qr.print_ascii(tty=not opts.ascii)
return

img = qr.make_image(image_factory=image_factory)
kwargs = {}
aliases: Optional[DrawerAliases] = getattr(
qr.image_factory, "drawer_aliases", None
)
if aliases and opts.factory_drawer:
if not aliases:
parser.error(f"The selected factory has no drawer aliases.")
if opts.factory_drawer not in aliases:
parser.error(
f"{opts.factory_drawer} factory drawer not found. Expected {commas(aliases)}"
)
drawer_cls, drawer_kwargs = aliases[opts.factory_drawer]
kwargs["module_drawer"] = drawer_cls(**drawer_kwargs)
elif opts.factory == "svg-circles":
from qrcode.image.styles.moduledrawers.svg import SvgCircleDrawer

kwargs["module_drawer"] = SvgCircleDrawer()
img = qr.make_image(**kwargs)

sys.stdout.flush()
# Use sys.stdout.buffer if available (Python 3), avoiding
Expand All @@ -123,5 +148,41 @@ def main(args=None):
img.save(stdout_buffer)


def get_factory(module: str) -> Type[BaseImage]:
if "." not in module:
raise ValueError("The image factory is not a full python path")
module, name = module.rsplit(".", 1)
imp = __import__(module, {}, {}, [name])
return getattr(imp, name)


def get_drawer_help() -> str:
help = {}
for alias, module in default_factories.items():
try:
image = get_factory(module)
except ImportError:
continue
aliases: Optional[DrawerAliases] = getattr(image, "drawer_aliases", None)
if not aliases:
continue
factories = help.setdefault(commas(aliases), set())
factories.add(alias)

return ". ".join(
f"For {commas(factories, 'and')}, use: {aliases}"
for aliases, factories in help.items()
)


def commas(items: Iterable[str], joiner="or") -> str:
items = tuple(items)
if not items:
return ""
if len(items) == 1:
return items[0]
return f"{', '.join(items[:-1])} {joiner} {items[-1]}"


if __name__ == "__main__":
main()
9 changes: 7 additions & 2 deletions qrcode/image/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import abc
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

from qrcode.image.styles.moduledrawers.base import QRModuleDrawer

if TYPE_CHECKING:
from qrcode.image.styles.moduledrawers.base import QRModuleDrawer
from qrcode.main import ActiveWithNeighbors, QRCode


DrawerAliases = Dict[str, Tuple[Type[QRModuleDrawer], Dict[str, Any]]]


class BaseImage:
"""
Base QRCode image output class.
Expand Down Expand Up @@ -103,6 +107,7 @@ def is_eye(self, row: int, col: int):

class BaseImageWithDrawer(BaseImage):
default_drawer_class: "Type[QRModuleDrawer]"
drawer_aliases: DrawerAliases = {}

def get_default_module_drawer(self) -> "QRModuleDrawer":
return self.default_drawer_class()
Expand Down
3 changes: 3 additions & 0 deletions qrcode/image/styles/moduledrawers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class QRModuleDrawer(abc.ABC):

needs_neighbors = False

def __init__(self, **kwargs):
pass

def initialize(self, img: "BaseImage") -> None:
self.img = img

Expand Down
22 changes: 19 additions & 3 deletions qrcode/image/svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import List, Literal, Optional, Type, Union, overload

import qrcode.image.base
from qrcode.image.styles.moduledrawers import svg as svg_drawers
from qrcode.image.styles.moduledrawers.base import QRModuleDrawer
from qrcode.image.styles.moduledrawers.svg import SvgPathSquareDrawer, SvgSquareDrawer

try:
import lxml.etree as ET
Expand All @@ -22,7 +22,7 @@ class SvgFragmentImage(qrcode.image.base.BaseImageWithDrawer):
_SVG_namespace = "http://www.w3.org/2000/svg"
kind = "SVG"
allowed_kinds = ("SVG",)
default_drawer_class: Type[QRModuleDrawer] = SvgSquareDrawer
default_drawer_class: Type[QRModuleDrawer] = svg_drawers.SvgSquareDrawer

def __init__(self, *args, **kwargs):
ET.register_namespace("svg", self._SVG_namespace)
Expand Down Expand Up @@ -84,6 +84,11 @@ class SvgImage(SvgFragmentImage):
"""

background: Optional[str] = None
drawer_aliases = {
"circle": (svg_drawers.SvgCircleDrawer, {}),
"gapped-circle": (svg_drawers.SvgCircleDrawer, {"size_ratio": Decimal(0.8)}),
"gapped-square": (svg_drawers.SvgSquareDrawer, {"size_ratio": Decimal(0.8)}),
}

def _svg(self, tag="svg", **kwargs):
svg = super()._svg(tag=tag, **kwargs)
Expand Down Expand Up @@ -120,7 +125,18 @@ class SvgPathImage(SvgImage):

needs_processing = True
path: ET.Element = None
default_drawer_class: Type[QRModuleDrawer] = SvgPathSquareDrawer
default_drawer_class: Type[QRModuleDrawer] = svg_drawers.SvgPathSquareDrawer
drawer_aliases = {
"circle": (svg_drawers.SvgPathCircleDrawer, {}),
"gapped-circle": (
svg_drawers.SvgPathCircleDrawer,
{"size_ratio": Decimal(0.8)},
),
"gapped-square": (
svg_drawers.SvgPathSquareDrawer,
{"size_ratio": Decimal(0.8)},
),
}

def __init__(self, *args, **kwargs):
self._subpaths: List[str] = []
Expand Down

0 comments on commit 172a172

Please sign in to comment.