Skip to content

Commit

Permalink
Improve TransformMatchingString to match longest common substrings by…
Browse files Browse the repository at this point in the history
… default
  • Loading branch information
3b1b committed Feb 4, 2023
1 parent b8fe7b0 commit fab917c
Showing 1 changed file with 51 additions and 19 deletions.
70 changes: 51 additions & 19 deletions manimlib/animation/transform_matching_parts.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from __future__ import annotations

import itertools as it

import numpy as np
from difflib import SequenceMatcher

from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.svg.string_mobject import StringMobject

from typing import TYPE_CHECKING

Expand Down Expand Up @@ -131,28 +127,64 @@ def __init__(
target: StringMobject,
matched_keys: Iterable[str] = [],
key_map: dict[str, str] = dict(),
matched_pairs: Iterable[tuple[Mobject, Mobject]] = [],
matched_pairs: Iterable[tuple[VMobject, VMobject]] = [],
**kwargs,
):
matched_pairs = list(matched_pairs) + [
*[(source[key], target[key]) for key in matched_keys],
*[(source[key1], target[key2]) for key1, key2 in key_map.items()],
*[
(source[substr], target[substr])
for substr in [
*source.get_specified_substrings(),
*target.get_specified_substrings(),
*source.get_symbol_substrings(),
*target.get_symbol_substrings(),
]
]
matched_pairs = [
*matched_pairs,
*self.matching_blocks(source, target, matched_keys, key_map),
]

super().__init__(
source, target,
matched_pairs=matched_pairs,
**kwargs,
)

def matching_blocks(
self,
source: StringMobject,
target: StringMobject,
matched_keys: Iterable[str],
key_map: dict[str, str]
) -> list[tuple[VMobject, VMobject]]:
syms1 = source.get_symbol_substrings()
syms2 = target.get_symbol_substrings()
counts1 = list(map(source.substr_to_path_count, syms1))
counts2 = list(map(target.substr_to_path_count, syms2))

# Start with user specified matches
blocks = [(source[key], target[key]) for key in matched_keys]
blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()]

# Nullify any intersections with those matches in the two symbol lists
for sub_source, sub_target in blocks:
for i in range(len(syms1)):
if source[i] in sub_source.family_members_with_points():
syms1[i] = "Null1"
for j in range(len(syms2)):
if target[j] in sub_target.family_members_with_points():
syms2[j] = "Null2"

# Group together longest matching substrings
while True:
matcher = SequenceMatcher(None, syms1, syms2)
match = matcher.find_longest_match(0, len(syms1), 0, len(syms2))
if match.size == 0:
break

i1 = sum(counts1[:match.a])
i2 = sum(counts2[:match.b])
size = sum(counts1[match.a:match.a + match.size])

blocks.append((source[i1:i1 + size], target[i2:i2 + size]))

for i in range(match.size):
syms1[match.a + i] = "Null1"
syms2[match.b + i] = "Null2"

return blocks


class TransformMatchingTex(TransformMatchingStrings):
"""Alias for TransformMatchingStrings"""
Expand Down

0 comments on commit fab917c

Please sign in to comment.