Skip to content

Commit

Permalink
Have matrix keep track of elements and ellipses as lists instead of V…
Browse files Browse the repository at this point in the history
…Groups
  • Loading branch information
3b1b committed Feb 13, 2024
1 parent ed3ac74 commit 7b577e9
Showing 1 changed file with 41 additions and 28 deletions.
69 changes: 41 additions & 28 deletions manimlib/mobject/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,42 @@ def __init__(
matrix, v_buff, h_buff, element_alignment_corner,
**element_config
)

# Create helpful groups for the elements
n_cols = len(self.mob_matrix[0])
self.elements = VGroup(*it.chain(*self.mob_matrix))
self.elements = [elem for row in self.mob_matrix for elem in row]
self.columns = VGroup(*(
VGroup(*(row[i] for row in self.mob_matrix))
for i in range(n_cols)
))
self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
self.ellipses = VGroup()
if height is not None:
self.rows.set_height(height - 2 * bracket_v_buff)
self.brackets = self.create_brackets(self.rows, bracket_v_buff, bracket_h_buff)
self.ellipses = []

# Add elements and brackets
self.elements.center()
self.add(self.elements)
if height is not None:
self.set_height(height - 2 * bracket_v_buff)
self.add_brackets(bracket_v_buff, bracket_h_buff)
self.add(self.ellipses)
self.add(*self.elements)
self.add(*self.brackets)
self.center()

# Potentially add ellipses
self.swap_entries_for_ellipses(
ellipses_row,
ellipses_col,
)

def copy(self, deep: bool = False):
result = super().copy(deep)
self_family = self.get_family()
copy_family = result.get_family()
for attr in ["elements", "ellipses"]:
setattr(result, attr, [
copy_family[self_family.index(mob)]
for mob in getattr(self, attr)
])
return result

def create_mobject_matrix(
self,
matrix: GenericMatrixType,
Expand Down Expand Up @@ -106,22 +118,18 @@ def element_to_mobject(self, element, **config) -> VMobject:
else:
return Tex(str(element), **config)

def add_brackets(self, v_buff: float, h_buff: float) -> Self:
height = len(self.mob_matrix)
def create_brackets(self, rows, v_buff: float, h_buff: float) -> VGroup:
brackets = Tex("".join((
R"\left[\begin{array}{c}",
*height * [R"\quad \\"],
*len(rows) * [R"\quad \\"],
R"\end{array}\right]",
)))
brackets.set_height(self.get_height() + v_buff)
brackets.set_height(rows.get_height() + v_buff)
l_bracket = brackets[:len(brackets) // 2]
r_bracket = brackets[len(brackets) // 2:]
l_bracket.next_to(self, LEFT, h_buff)
r_bracket.next_to(self, RIGHT, h_buff)
brackets.set_submobjects([l_bracket, r_bracket])
self.brackets = VGroup(l_bracket, r_bracket)
self.add(*brackets)
return self
l_bracket.next_to(rows, LEFT, h_buff)
r_bracket.next_to(rows, RIGHT, h_buff)
return VGroup(l_bracket, r_bracket)

def get_column(self, index: int):
if not 0 <= index < len(self.columns):
Expand Down Expand Up @@ -150,6 +158,14 @@ def add_background_to_entries(self) -> Self:
mob.add_background_rectangle()
return self

def swap_entry_for_dots(self, entry, dots):
dots.move_to(entry)
entry.become(dots)
if entry in self.elements:
self.elements.remove(entry)
if entry not in self.ellipses:
self.ellipses.append(entry)

def swap_entries_for_ellipses(
self,
row_index: Optional[int] = None,
Expand All @@ -169,24 +185,18 @@ def swap_entries_for_ellipses(
use_vdots = row_index is not None and -len(rows) <= row_index < len(rows)
use_hdots = col_index is not None and -len(cols) <= col_index < len(cols)

def swap_entry_for_dots(entry, dots):
dots.move_to(entry)
entry.become(dots)
self.elements.remove(entry)
self.ellipses.add(entry)

if use_vdots:
for column in cols:
# Add vdots
dots = Tex(R"\vdots")
dots.set_height(vdots_height)
swap_entry_for_dots(column[row_index], dots)
self.swap_entry_for_dots(column[row_index], dots)
if use_hdots:
for row in rows:
# Add hdots
dots = Tex(R"\hdots")
dots.set_width(hdots_width)
swap_entry_for_dots(row[col_index], dots)
self.swap_entry_for_dots(row[col_index], dots)
if use_vdots and use_hdots:
rows[row_index][col_index].rotate(-45 * DEGREES)
return self
Expand All @@ -195,10 +205,13 @@ def get_mob_matrix(self) -> VMobjectMatrixType:
return self.mob_matrix

def get_entries(self) -> VGroup:
return self.elements
return VGroup(*self.elements)

def get_brackets(self) -> VGroup:
return self.brackets
return VGroup(*self.brackets)

def get_ellipses(self) -> VGroup:
return VGroup(*self.ellipses)


class DecimalMatrix(Matrix):
Expand Down

0 comments on commit 7b577e9

Please sign in to comment.