Skip to content

Commit

Permalink
Fix file type checks in data splits for contrastive training example …
Browse files Browse the repository at this point in the history
…script (#31720)

fix data split file type checks
  • Loading branch information
npyoung committed Jul 10, 2024
1 parent e9eeeda commit a0a3e2f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions examples/pytorch/contrastive-image-text/run_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ def __post_init__(self):
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension == "json", "`validation_file` should be a json file."
if self.test_file is not None:
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."


dataset_name_mapping = {
Expand Down
6 changes: 3 additions & 3 deletions examples/tensorflow/contrastive-image-text/run_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ def __post_init__(self):
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension == "json", "`validation_file` should be a json file."
if self.test_file is not None:
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."


dataset_name_mapping = {
Expand Down

0 comments on commit a0a3e2f

Please sign in to comment.