Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL Bucket, Light Logic Refactor and Docstring Update for Alibaba Provider #23891

Merged
merged 12 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 34 additions & 36 deletions airflow/providers/alibaba/cloud/hooks/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ def provide_bucket_name(func: T) -> T:
def wrapper(*args, **kwargs) -> T:
bound_args = function_signature.bind(*args, **kwargs)
self = args[0]
if 'bucket_name' not in bound_args.arguments or bound_args.arguments['bucket_name'] is None:
if self.oss_conn_id:
connection = self.get_connection(self.oss_conn_id)
if connection.schema:
bound_args.arguments['bucket_name'] = connection.schema
if (
'bucket_name' not in bound_args.arguments or bound_args.arguments['bucket_name'] is None
) and self.oss_conn_id:
vincentkoc marked this conversation as resolved.
Show resolved Hide resolved
connection = self.get_connection(self.oss_conn_id)
if connection.schema:
bound_args.arguments['bucket_name'] = connection.schema

return func(*bound_args.args, **bound_args.kwargs)

Expand Down Expand Up @@ -92,10 +93,7 @@ class OSSHook(BaseHook):
def __init__(self, region: Optional[str] = None, oss_conn_id='oss_default', *args, **kwargs) -> None:
    """
    Initialise the hook: load the connection and resolve the OSS region.

    :param region: OSS region to use; when ``None`` the region is taken from
        ``get_default_region()``
    :param oss_conn_id: connection id handed to ``get_connection``
    """
    self.oss_conn_id = oss_conn_id
    # Load the connection before resolving the region: get_default_region()
    # reads its settings from self.oss_conn.
    self.oss_conn = self.get_connection(oss_conn_id)
    self.region = self.get_default_region() if region is None else region
    super().__init__(*args, **kwargs)

def get_conn(self) -> "Connection":
Expand Down Expand Up @@ -148,7 +146,7 @@ def get_bucket(self, bucket_name: Optional[str] = None) -> oss2.api.Bucket:
"""
auth = self.get_credential()
assert self.region is not None
return oss2.Bucket(auth, 'http://oss-' + self.region + '.aliyuncs.com', bucket_name)
return oss2.Bucket(auth, f'https://oss-{self.region}.aliyuncs.com', bucket_name)

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -163,7 +161,7 @@ def load_string(self, key: str, content: str, bucket_name: Optional[str] = None)
try:
self.get_bucket(bucket_name).put_object(key, content)
except Exception as e:
raise AirflowException(f"Errors: {e}")
raise AirflowException(f"Errors: {e}") from e
vincentkoc marked this conversation as resolved.
Show resolved Hide resolved

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -183,7 +181,7 @@ def upload_local_file(
try:
self.get_bucket(bucket_name).put_object_from_file(key, file)
except Exception as e:
raise AirflowException(f"Errors when upload file: {e}")
raise AirflowException(f"Errors when upload file: {e}") from e

@provide_bucket_name
@unify_bucket_name_and_key
Expand Down Expand Up @@ -226,7 +224,7 @@ def delete_object(
self.get_bucket(bucket_name).delete_object(key)
except Exception as e:
self.log.error(e)
raise AirflowException(f"Errors when deleting: {key}")
raise AirflowException(f"Errors when deleting: {key}") from e

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -245,7 +243,7 @@ def delete_objects(
self.get_bucket(bucket_name).batch_delete_objects(key)
except Exception as e:
self.log.error(e)
raise AirflowException(f"Errors when deleting: {key}")
raise AirflowException(f"Errors when deleting: {key}") from e

@provide_bucket_name
def delete_bucket(
Expand All @@ -261,7 +259,7 @@ def delete_bucket(
self.get_bucket(bucket_name).delete_bucket()
except Exception as e:
self.log.error(e)
raise AirflowException(f"Errors when deleting: {bucket_name}")
raise AirflowException(f"Errors when deleting: {bucket_name}") from e

@provide_bucket_name
def create_bucket(
Expand All @@ -277,7 +275,7 @@ def create_bucket(
self.get_bucket(bucket_name).create_bucket()
except Exception as e:
self.log.error(e)
raise AirflowException(f"Errors when create bucket: {bucket_name}")
raise AirflowException(f"Errors when create bucket: {bucket_name}") from e

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -295,7 +293,7 @@ def append_string(self, bucket_name: Optional[str], content: str, key: str, pos:
self.get_bucket(bucket_name).append_object(key, pos, content)
except Exception as e:
self.log.error(e)
raise AirflowException(f"Errors when append string for object: {key}")
raise AirflowException(f"Errors when append string for object: {key}") from e

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -311,7 +309,7 @@ def read_key(self, bucket_name: Optional[str], key: str) -> str:
return self.get_bucket(bucket_name).get_object(key).read().decode("utf-8")
except Exception as e:
self.log.error(e)
raise AirflowException(f"Errors when read bucket object: {key}")
raise AirflowException(f"Errors when read bucket object: {key}") from e

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -327,7 +325,7 @@ def head_key(self, bucket_name: Optional[str], key: str) -> oss2.models.HeadObje
return self.get_bucket(bucket_name).head_object(key)
except Exception as e:
self.log.error(e)
raise AirflowException(f"Errors when head bucket object: {key}")
raise AirflowException(f"Errors when head bucket object: {key}") from e

@provide_bucket_name
@unify_bucket_name_and_key
Expand All @@ -344,36 +342,36 @@ def key_exist(self, bucket_name: Optional[str], key: str) -> bool:
return self.get_bucket(bucket_name).object_exists(key)
except Exception as e:
self.log.error(e)
raise AirflowException(f"Errors when check bucket object existence: {key}")
raise AirflowException(f"Errors when check bucket object existence: {key}") from e

def get_credential(self) -> oss2.auth.Auth:
    """
    Build the OSS auth object from the connection's extra configuration.

    The extra config must carry ``auth_type``; only ``'AK'`` is accepted, in
    which case ``access_key_id`` and ``access_key_secret`` are both required.

    :return: ``oss2.Auth`` created from the configured key id and secret
    :raises Exception: if ``auth_type`` is missing or is not ``'AK'``, or if
        ``access_key_id`` / ``access_key_secret`` is missing or empty
    """
    extra_config = self.oss_conn.extra_dejson
    auth_type = extra_config.get('auth_type', None)
    if not auth_type:
        raise Exception("No auth_type specified in extra_config. ")

    # Guard clause: reject every auth type other than access-key up front.
    if auth_type != 'AK':
        raise Exception(f"Unsupported auth_type: {auth_type}")

    oss_access_key_id = extra_config.get('access_key_id', None)
    oss_access_key_secret = extra_config.get('access_key_secret', None)
    if not oss_access_key_id:
        raise Exception(f"No access_key_id is specified for connection: {self.oss_conn_id}")

    if not oss_access_key_secret:
        raise Exception(f"No access_key_secret is specified for connection: {self.oss_conn_id}")

    return oss2.Auth(oss_access_key_id, oss_access_key_secret)

def get_default_region(self) -> Optional[str]:
    """
    Read the region from the connection's extra configuration.

    The extra config must carry ``auth_type``; only ``'AK'`` is accepted, in
    which case a non-empty ``region`` entry is required.

    :return: the value of the ``region`` entry in the extra config
    :raises Exception: if ``auth_type`` is missing or is not ``'AK'``, or if
        ``region`` is missing or empty
    """
    extra_config = self.oss_conn.extra_dejson
    auth_type = extra_config.get('auth_type', None)
    if not auth_type:
        raise Exception("No auth_type specified in extra_config. ")

    # Guard clause: reject every auth type other than access-key up front.
    if auth_type != 'AK':
        raise Exception(f"Unsupported auth_type: {auth_type}")

    default_region = extra_config.get('region', None)
    if not default_region:
        raise Exception(f"No region is specified for connection: {self.oss_conn_id}")
    return default_region
31 changes: 15 additions & 16 deletions airflow/providers/alibaba/cloud/log/oss_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import os
import pathlib
import sys

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -61,6 +63,7 @@ def hook(self):
)

def set_context(self, ti):
"""This function is used to set the context of the handler"""
super().set_context(ti)
# Local location and remote location is needed to open and
# upload local log file to OSS remote storage.
Expand Down Expand Up @@ -91,8 +94,7 @@ def close(self):
remote_loc = self.log_relative_path
if os.path.exists(local_loc):
# read log and remove old logs to get just the latest additions
with open(local_loc) as logfile:
log = logfile.read()
log = pathlib.Path(local_loc).read_text()
self.oss_write(log, remote_loc)

# Mark closed so we don't double write if close is called twice
Expand All @@ -114,15 +116,14 @@ def _read(self, ti, try_number, metadata=None):
log_relative_path = self._render_filename(ti, try_number)
remote_loc = log_relative_path

if self.oss_log_exists(remote_loc):
# If OSS remote file exists, we do not fetch logs from task instance
# local machine even if there are errors reading remote logs, as
# returned remote_log will contain error messages.
remote_log = self.oss_read(remote_loc, return_error=True)
log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n'
return log, {'end_of_log': True}
else:
if not self.oss_log_exists(remote_loc):
return super()._read(ti, try_number)
# If OSS remote file exists, we do not fetch logs from task instance
# local machine even if there are errors reading remote logs, as
# returned remote_log will contain error messages.
remote_log = self.oss_read(remote_loc, return_error=True)
log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n'
return log, {'end_of_log': True}

def oss_log_exists(self, remote_log_location):
"""
Expand All @@ -131,11 +132,9 @@ def oss_log_exists(self, remote_log_location):
:param remote_log_location: log's location in remote storage
:return: True if location exists else False
"""
oss_remote_log_location = self.base_folder + '/' + remote_log_location
try:
oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
with contextlib.suppress(Exception):
return self.hook.key_exist(self.bucket_name, oss_remote_log_location)
except Exception:
pass
return False

def oss_read(self, remote_log_location, return_error=False):
Expand All @@ -148,7 +147,7 @@ def oss_read(self, remote_log_location, return_error=False):
error occurs. Otherwise returns '' when an error occurs.
"""
try:
oss_remote_log_location = self.base_folder + '/' + remote_log_location
oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
self.log.info("read remote log: %s", oss_remote_log_location)
return self.hook.read_key(self.bucket_name, oss_remote_log_location)
except Exception:
Expand All @@ -168,7 +167,7 @@ def oss_write(self, log, remote_log_location, append=True):
:param append: if False, any existing log file is overwritten. If True,
the new log is appended to any existing logs.
"""
oss_remote_log_location = self.base_folder + '/' + remote_log_location
oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
pos = 0
if append and self.oss_log_exists(oss_remote_log_location):
head = self.hook.head_key(self.bucket_name, oss_remote_log_location)
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/alibaba/cloud/sensors/oss_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def __init__(
self.hook: Optional[OSSHook] = None

def poke(self, context: 'Context'):

"""
Check if the object exists in the bucket to pull key.
@param self - the object itself
@param context - the context of the object
@returns True if the object exists, False otherwise
"""
if self.bucket_name is None:
parsed_url = urlparse(self.bucket_key)
if parsed_url.netloc == '':
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/alibaba/cloud/hooks/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_get_bucket(self, mock_oss2, mock_get_credential):
self.hook.get_bucket('mock_bucket_name')
mock_get_credential.assert_called_once_with()
mock_oss2.Bucket.assert_called_once_with(
mock_get_credential.return_value, 'http://oss-mock_region.aliyuncs.com', MOCK_BUCKET_NAME
mock_get_credential.return_value, 'https://oss-mock_region.aliyuncs.com', MOCK_BUCKET_NAME
)

@mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
Expand Down