Skip to content

Commit

Permalink
Modernizing MySqlOperator + unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
mistercrunch committed Mar 10, 2015
1 parent 8efa9f8 commit 9574863
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 42 deletions.
10 changes: 1 addition & 9 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ TODO
* Backfill wizard

#### unittests
* Increase coverage, now 70ish%
* Increase coverage, now 80ish%

#### Command line
* `airflow task_state dag_id task_id YYYY-MM-DD`
Expand All @@ -14,25 +14,17 @@ TODO
* S3Sensor
* BaseDataTransferOperator
* File2MySqlOperator
* PythonOperator
* DagTaskSensor for cross dag dependencies
* PIG

#### Macros
* Previous execution timestamp
* Previous ds
* ...

#### Frontend
*

#### Backend
* Callbacks
* Master auto dag refresh at time intervals
* Set default args at the DAG level?
* Prevent timezone chagne on import
* Add decorator to timeout imports on master process [lib](https://github.com/pnpnpn/timeout-decorator)
* Mysql port should carry through (using default now)
* Make authentication universal

#### Misc
Expand Down
40 changes: 11 additions & 29 deletions airflow/hooks/mysql_hook.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,25 @@
import MySQLdb
from airflow import settings
from airflow.models import Connection

from airflow.hooks.base_hook import BaseHook

class MySqlHook(object):

class MySqlHook(BaseHook):
'''
Interact with MySQL.
'''

def __init__(
self, host=None, login=None,
psw=None, db=None, mysql_conn_id=None):
if not mysql_conn_id:
self.host = host
self.login = login
self.psw = psw
self.db = db
else:
session = settings.Session()
db = session.query(
Connection).filter(
Connection.conn_id == mysql_conn_id)
if db.count() == 0:
raise Exception("The mysql_dbid you provided isn't defined")
else:
db = db.all()[0]
self.host = db.host
self.login = db.login
self.psw = db.password
self.db = db.schema
session.commit()
session.close()
self, mysql_conn_id=None):

conn = self.get_connection(mysql_conn_id)
self.conn = conn

def get_conn(self):
conn = MySQLdb.connect(
self.host,
self.login,
self.psw,
self.db)
self.conn.host,
self.conn.login,
self.conn.password,
self.conn.schema)
return conn

def get_records(self, sql):
Expand Down
2 changes: 1 addition & 1 deletion airflow/operators/mysql_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class MySqlOperator(BaseOperator):
ui_color = '#ededed'

@apply_defaults
def __init__(self, sql, mysql_conn_id, *args, **kwargs):
def __init__(self, sql, mysql_conn_id='mysql_default', *args, **kwargs):
super(MySqlOperator, self).__init__(*args, **kwargs)

self.hook = MySqlHook(mysql_conn_id=mysql_conn_id)
Expand Down
2 changes: 1 addition & 1 deletion airflow/www/templates/airflow/dag.html
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ <h4 class="modal-title" id="myModalLabel">
</button>
<hr/>
<button id="btn_clear" type="button" class="btn btn-primary">
Clear
<span class="glyphicon glyphicon-trash" aria-hidden="true"></span> Clear
</button>
<span class="btn-group">
<button id="btn_past"
Expand Down
23 changes: 21 additions & 2 deletions tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
DEFAULT_DATE = datetime(2015, 1, 1)
configuration.test_mode()


class HivePrestoTest(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -227,5 +226,25 @@ def tearDown(self):
pass


class MySqlTest(unittest.TestCase):

def setUp(self):
configuration.test_mode()
utils.initdb()
args = {'owner': 'airflow', 'start_date': datetime(2015, 1, 1)}
dag = DAG('hive_test', default_args=args)
self.dag = dag

def mysql_operator_test(self):
sql = """
CREATE TABLE IF NOT EXISTS test_airflow (
dummy VARCHAR(50)
);
"""
t = operators.MySqlOperator(
task_id='basic_mysql', sql=sql, dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)


if __name__ == '__main__':
unittest.main()
unittest.main()

0 comments on commit 9574863

Please sign in to comment.