Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(debug): fix --debug flag and associated tests #552

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions safety/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,32 @@

LOG = logging.getLogger(__name__)

def preprocess_args(f):
if '--debug' in sys.argv:
index = sys.argv.index('--debug')
if len(sys.argv) > index + 1:
next_arg = sys.argv[index + 1]
if next_arg in ('1', 'true'):
sys.argv.pop(index + 1) # Remove the next argument (1 or true)
return f

def configure_logger(ctx, param, debug):
level = logging.CRITICAL

if debug:
level = logging.DEBUG

logging.basicConfig(format='%(asctime)s %(name)s => %(message)s', level=level)
logging.basicConfig(format='%(asctime)s %(name)s => %(message)s', level=level)

@click.group(cls=SafetyCLILegacyGroup, help=CLI_MAIN_INTRODUCTION, epilog=DEFAULT_EPILOG)
@auth_options()
@proxy_options
@click.option('--disable-optional-telemetry', default=False, is_flag=True, show_default=True, help=CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP)
@click.option('--debug', default=False, help=CLI_DEBUG_HELP, callback=configure_logger)
@click.option('--debug', is_flag=True, help=CLI_DEBUG_HELP, callback=configure_logger)
@click.version_option(version=get_safety_version())
@click.pass_context
@inject_session
@preprocess_args
def cli(ctx, debug, disable_optional_telemetry):
"""
Scan and secure Python projects against package vulnerabilities. To get started navigate to a Python project and run `safety scan`.
Expand Down Expand Up @@ -107,7 +116,7 @@ def inner(ctx, *args, **kwargs):
kwargs.pop('proxy_protocol', None)
kwargs.pop('proxy_host', None)
kwargs.pop('proxy_port', None)

if ctx.get_parameter_source("json_version") != click.core.ParameterSource.DEFAULT and not (
save_json or json or output == 'json'):
raise click.UsageError(
Expand All @@ -128,8 +137,8 @@ def inner(ctx, *args, **kwargs):
proxy_dictionary=None)
audit_and_monitor = (audit_and_monitor and server_audit_and_monitor)

kwargs.update({"auto_remediation_limit": auto_remediation_limit,
"policy_file":policy_file,
kwargs.update({"auto_remediation_limit": auto_remediation_limit,
"policy_file":policy_file,
"audit_and_monitor": audit_and_monitor})

except SafetyError as e:
Expand Down Expand Up @@ -441,18 +450,18 @@ def validate(ctx, name, version, path):
if not os.path.exists(path):
click.secho(f'The path "{path}" does not exist.', fg='red', file=sys.stderr)
sys.exit(EXIT_CODE_FAILURE)

if version not in ["3.0", "2.0", None]:
click.secho(f'Version "{version}" is not a valid value, allowed values are 3.0 and 2.0. Use --path to specify the target file.', fg='red', file=sys.stderr)
sys.exit(EXIT_CODE_FAILURE)

def fail_validation(e):
click.secho(str(e).lstrip(), fg='red', file=sys.stderr)
sys.exit(EXIT_CODE_FAILURE)

if not version:
version = "3.0"

result = ""

if version == "3.0":
Expand All @@ -463,7 +472,7 @@ def fail_validation(e):
policy = load_policy_file(Path(path))
except Exception as e:
fail_validation(e)

click.secho(f"The Safety policy ({version}) file " \
"(Used for scan and system-scan commands) " \
"was successfully parsed " \
Expand All @@ -478,18 +487,18 @@ def fail_validation(e):
sys.exit(EXIT_CODE_FAILURE)

del values['raw']

result = json.dumps(values, indent=4, default=str)

click.secho("The Safety policy file " \
"(Valid only for the check command) " \
"was successfully parsed with the " \
"following values:", fg="green")

console.print_json(result)


@cli.command(cls=SafetyCLILegacyCommand,
@cli.command(cls=SafetyCLILegacyCommand,
help=CLI_CONFIGURE_HELP,
utility_command=True)
@click.option("--proxy-protocol", "-pr", type=click.Choice(['http', 'https']), default='https', cls=DependentOption,
Expand Down Expand Up @@ -519,8 +528,8 @@ def fail_validation(e):
@click.option("--save-to-system/--save-to-user", default=False, is_flag=True,
help=CLI_CONFIGURE_SAVE_TO_SYSTEM)
@click.pass_context
def configure(ctx, proxy_protocol, proxy_host, proxy_port, proxy_timeout,
proxy_required, organization_id, organization_name, stage,
def configure(ctx, proxy_protocol, proxy_host, proxy_port, proxy_timeout,
proxy_required, organization_id, organization_name, stage,
save_to_system):
"""
Configure global settings, like proxy settings and organization details
Expand Down Expand Up @@ -565,7 +574,7 @@ def configure(ctx, proxy_protocol, proxy_host, proxy_port, proxy_timeout,
'host': proxy_host,
'port': str(proxy_port)
})

if not config.has_section(PROXY_SECTION_NAME):
config.add_section(PROXY_SECTION_NAME)

Expand Down Expand Up @@ -669,7 +678,7 @@ def check_updates(ctx: typer.Context,

if not data:
raise SafetyException("No data found.")

console.print("[green]Safety CLI is authenticated:[/green]")

from rich.padding import Padding
Expand All @@ -696,7 +705,7 @@ def check_updates(ctx: typer.Context,
f"If Safety was installed from a requirements file, update Safety to version {latest_available_version} in that requirements file."
)
console.print()
# `pip -i <source_url> install safety=={latest_available_version}` OR
# `pip -i <source_url> install safety=={latest_available_version}` OR
console.print(f"Pip: To install the updated version of Safety directly via pip, run: `pip install safety=={latest_available_version}`")

if console.quiet:
Expand All @@ -717,5 +726,5 @@ def check_updates(ctx: typer.Context,

cli.add_command(alert)

if __name__ == "__main__":
if __name__ == "__main__":
cli()
52 changes: 51 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
import os
import sys
import shutil
import tempfile
import unittest
Expand All @@ -15,6 +17,8 @@
from safety import cli
from safety.models import CVE, SafetyRequirement, Severity, Vulnerability
from safety.util import Package, SafetyContext
from safety.auth.models import Auth
from safety_schemas.models.base import AuthenticationType


def get_vulnerability(vuln_kwargs=None, cve_kwargs=None, pkg_kwargs=None):
Expand Down Expand Up @@ -513,4 +517,50 @@ def test_license_with_file(self, fetch_database_url):
test_filename = os.path.join(dirname, "reqs_4.txt")
result = self.runner.invoke(cli.cli, ['license', '--key', 'foo', '--file', test_filename])
print(result.stdout)
self.assertEqual(result.exit_code, 0)
self.assertEqual(result.exit_code, 0)

@patch('safety.auth.cli.get_auth_info', return_value={'email': 'test@test.com'})
@patch.object(Auth, 'is_valid', return_value=True)
@patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value=AuthenticationType.TOKEN)
@patch('builtins.input', lambda *args: '')
@patch('safety.safety.fetch_database', return_value={'vulnerable_packages': []})
def test_debug_flag(self, mock_get_auth_info, mock_is_valid, mock_get_auth_type, mock_fetch_database):
result = self.runner.invoke(cli.cli, ['--debug', 'scan'])
assert result.exit_code == 0, f"CLI exited with code {result.exit_code} and output: {result.output} and error: {result.stderr}"
assert "for known security issues using default" in result.output

@patch('safety.auth.cli.get_auth_info', return_value={'email': 'test@test.com'})
@patch.object(Auth, 'is_valid', return_value=True)
@patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value=AuthenticationType.TOKEN)
@patch('builtins.input', lambda *args: '')
@patch('safety.safety.fetch_database', return_value={'vulnerable_packages': []})
def test_debug_flag_with_value_1(self, mock_get_auth_info, mock_is_valid, mock_get_auth_type, mock_fetch_database):
sys.argv = ['safety', '--debug', '1', 'scan']

@cli.preprocess_args
def dummy_function():
pass

# Extract the preprocessed arguments from sys.argv
preprocessed_args = sys.argv[1:] # Exclude the script name 'safety'

# Assert the preprocessed arguments
assert preprocessed_args == ['--debug', 'scan'], f"Preprocessed args: {preprocessed_args}"

@patch('safety.auth.cli.get_auth_info', return_value={'email': 'test@test.com'})
@patch.object(Auth, 'is_valid', return_value=True)
@patch('safety.auth.utils.SafetyAuthSession.get_authentication_type', return_value=AuthenticationType.TOKEN)
@patch('builtins.input', lambda *args: '')
@patch('safety.safety.fetch_database', return_value={'vulnerable_packages': []})
def test_debug_flag_with_value_true(self, mock_get_auth_info, mock_is_valid, mock_get_auth_type, mock_fetch_database):
sys.argv = ['safety', '--debug', 'true', 'scan']

@cli.preprocess_args
def dummy_function():
pass

# Extract the preprocessed arguments from sys.argv
preprocessed_args = sys.argv[1:] # Exclude the script name 'safety'

# Assert the preprocessed arguments
assert preprocessed_args == ['--debug', 'scan'], f"Preprocessed args: {preprocessed_args}"
Loading