Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

selectGroupSlices stage should 🏃 before Select samples stage #3852

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/packages/core/src/components/Actions/utils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ export const tagStatistics = selectorFamily<
? {
id: modal ? get(groupId) : null,
slices: get(fos.currentSlices(modal)),
slice: get(fos.currentSlice(modal)),
mode: get(groupStatistics(modal)),
}
: null,
Expand Down
1 change: 0 additions & 1 deletion app/packages/looker/src/elements/common/controls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ export class ControlsElement<
error,
loaded,
}: Readonly<State>) {
console.log("showcontrols is ", showControls);
showControls = showControls && !disableControls && !error && loaded;
if (this.showControls === showControls) {
return this.element;
Expand Down
1 change: 1 addition & 0 deletions app/packages/state/src/hooks/useSetExpandedSample.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export default () => {
) => {
set(groupAtoms.groupId, groupId || null);
set(currentModalSample, { id, index });

reset(groupAtoms.dynamicGroupIndex);
reset(dynamicGroupCurrentElementIndex);
groupByFieldValue &&
Expand Down
2 changes: 1 addition & 1 deletion app/packages/state/src/recoil/modal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ export const modalSampleIndex = selector<number>({
key: "modalSampleIndex",
get: ({ get }) => {
const current = get(currentModalSample);

if (!current) {
throw new Error("modal sample is not defined");
}
Expand Down Expand Up @@ -180,6 +179,7 @@ export const modalSample = graphQLSelector<
},
variables: ({ get }) => {
const current = get(currentModalSample);

if (current === null) return null;

const slice = get(groupSlice);
Expand Down
6 changes: 6 additions & 0 deletions fiftyone/core/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8504,6 +8504,12 @@ def repr_ViewExpression(self, expr, level):
MatchTags,
Select,
SelectBy,
SelectGroupSlices,
Skip,
Take,
}

# Registry of select stages that should select first
_STAGES_THAT_SELECT_FIRST = {
SelectGroupSlices,
}
4 changes: 4 additions & 0 deletions fiftyone/core/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,10 @@ def make_optimized_select_view(
else:
if view.media_type == fom.GROUP and view.group_slices and flatten:
view = view.select_group_slices(_allow_mixed=True)
else:
for stage in stages:
if type(stage) in fost._STAGES_THAT_SELECT_FIRST:
view = view._add_view_stage(stage, validate=False)

view = view.select(sample_ids, ordered=ordered)

Expand Down
1 change: 0 additions & 1 deletion fiftyone/server/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def get_view(
pagination_data=pagination_data,
extended_stages=extended_stages,
)

return view


Expand Down
50 changes: 33 additions & 17 deletions tests/unittests/view_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4257,23 +4257,7 @@ def test_view_field_copy(self):
self.assertEqual(str(field), str(deepcopy(field)))

def test_make_optimized_select_view_group_dataset(self):
dataset = fo.Dataset()
dataset.add_group_field("group", default="center")

groups = ["left", "center", "right"]
filepaths = [
[str(i) + str(j) + ".jpg" for i in groups] for j in range(3)
]
filepaths = [dict(zip(groups, fps)) for fps in zip(*filepaths)]
group = fo.Group()
samples = []
for fps in filepaths:
for name, filepath in fps.items():
samples.append(
fo.Sample(filepath=filepath, group=group.element(name))
)

sample_ids = dataset.add_samples(samples)
dataset, sample_ids = self._make_group_dataset()

optimized_view = fov.make_optimized_select_view(
dataset, sample_ids[0], flatten=True
Expand All @@ -4284,6 +4268,22 @@ def test_make_optimized_select_view_group_dataset(self):
]
self.assertEqual(optimized_view._all_stages, expected_stages)

def test_make_optimized_select_view_select_group_slices_before_sample_selection(
self,
):
dataset, sample_ids = self._make_group_dataset()
view = dataset.select_group_slices(["left", "right"])

optimized_view = fov.make_optimized_select_view(
view,
sample_ids[1],
)

first_stage, second_stage = optimized_view._stages
# the order matters
self.assertEqual(type(first_stage), fosg.SelectGroupSlices)
self.assertEqual(type(second_stage), fosg.Select)

def test_selected_samples_in_group_slices(self):
(dataset, selected_ids) = self._make_group_by_group_dataset()
view = dataset.view()
Expand All @@ -4307,6 +4307,22 @@ def test_selected_samples_in_group_slices(self):
)
self.assertEqual(len(optimized_view), 2)

def _make_group_dataset(self):
dataset = fo.Dataset()
dataset.add_group_field("group", default="left")
groups = ["left", "right"]
filepaths = [str(i) + ".jpg" for i in groups]

filepaths = [dict(zip(groups, fps)) for fps in zip(*filepaths)]
group = fo.Group()
samples = []
for fps in filepaths:
for name, filepath in fps.items():
samples.append(
fo.Sample(filepath=filepath, group=group.element(name))
)
return dataset, dataset.add_samples(samples)

def _make_group_by_group_dataset(self):
dataset = fo.Dataset()
dataset.add_group_field("group_field", default="left")
Expand Down
Loading