forked from HumanSignal/label-studio-ml-backend
-
Notifications
You must be signed in to change notification settings - Fork 0
/
api.py
147 lines (112 loc) · 3.76 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import logging
import os
from flask import Flask, request, jsonify
from .model import LabelStudioMLBase
from .exceptions import exception_handler
from .cache import create_cache, BaseCache
logger = logging.getLogger(__name__)

# The Flask application every route in this module registers on;
# init_app() returns it.
_server = Flask(__name__)

# Placeholders that init_app() overwrites: the model class to serve and the
# cache object passed to it. Until then `cache` is the BaseCache class
# itself, not an instance.
MODEL_CLASS = LabelStudioMLBase
cache = BaseCache
def init_app(model_class):
    """Register the model class to serve and create the shared cache.

    Args:
        model_class: A subclass of ``LabelStudioMLBase``. The request
            handlers in this module instantiate it as
            ``MODEL_CLASS(project_id, cache)``.

    Returns:
        The module-level Flask application.

    Raises:
        ValueError: If ``model_class`` is not a class deriving from
            ``LabelStudioMLBase``.
    """
    global MODEL_CLASS, cache
    # issubclass() raises TypeError for non-class arguments; check for a
    # class first so every invalid input gets the same ValueError.
    if not (isinstance(model_class, type) and issubclass(model_class, LabelStudioMLBase)):
        # Use the base class's own name: `LabelStudioMLBase.__class__.__name__`
        # would be the name of its metaclass (e.g. 'type'), not the class.
        raise ValueError('Inference class should be the subclass of ' + LabelStudioMLBase.__name__)
    MODEL_CLASS = model_class
    # Cache backend and storage directory come from the environment,
    # defaulting to 'sqlite' and the current directory.
    cache = create_cache(
        os.getenv('CACHE_TYPE', 'sqlite'),
        path=os.getenv('MODEL_DIR', '.'))
    return _server
@_server.route('/predict', methods=['POST'])
@exception_handler
def _predict():
    """Run the configured model on a batch of tasks.

    Example request body::

        {
            'tasks': tasks,
            'model_version': model_version,
            'project': '{project.id}.{int(project.created_at.timestamp())}',
            'label_config': project.label_config,
            'params': {
                'login': project.task_data_login,
                'password': project.task_data_password,
                'context': context,
            },
        }

    Only the portion of ``project`` before the first ``.`` is used as the
    project id. ``model_version`` is not read by this handler. Everything in
    ``params`` except ``context`` is forwarded to ``predict()`` as keyword
    arguments.

    Returns:
        JSON ``{'results': ...}`` holding the predictions (in LS format)
        returned by the model's ``predict()``.
    """
    payload = request.json
    task_list = payload.get('tasks')
    extra_kwargs = payload.get('params') or {}
    project_id = payload.get('project').split('.', 1)[0]
    # Pull `context` out of the params so it is passed to predict() once,
    # as the explicit keyword, rather than also through **extra_kwargs.
    ctx = extra_kwargs.pop('context', {})
    model = MODEL_CLASS(project_id, cache)
    model.use_label_config(payload.get('label_config'))
    predictions = model.predict(task_list, context=ctx, **extra_kwargs)
    return jsonify({'results': predictions})
@_server.route('/setup', methods=['POST'])
@exception_handler
def _setup():
    """Apply a project's labeling config to the model and report its version.

    Reads ``project`` from the JSON body (the id is the part before the first
    ``.``) and ``schema`` (the label config passed to ``use_label_config``).

    Returns:
        JSON ``{'model_version': ...}`` with the value of
        ``model.get('model_version')``.
    """
    body = request.json
    project_id = body.get('project').split('.', 1)[0]
    model = MODEL_CLASS(project_id, cache)
    model.use_label_config(body.get('schema'))
    return jsonify({'model_version': model.get('model_version')})
# Webhook `action` values that make /webhook call model.fit(); any other
# action is answered with an 'Unknown event' status and ignored.
_ANNOTATION_TRAIN_EVENTS = ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'ANNOTATION_DELETED')
TRAIN_EVENTS = _ANNOTATION_TRAIN_EVENTS + ('PROJECT_UPDATED',)
@_server.route('/webhook', methods=['POST'])
def webhook():
    """Handle a webhook call and train the model on relevant events.

    The JSON body carries an ``action`` name plus a ``project`` object with
    ``id`` and ``label_config`` keys. Actions outside ``TRAIN_EVENTS``,
    including a body with no ``action`` at all, are acknowledged and ignored.

    Returns:
        ``({'status': 'Unknown event'}, 200)`` for ignored actions, otherwise
        ``({}, 201)`` once ``model.fit(event, data)`` has returned.
    """
    data = request.json
    # pop() removes `action` so fit() receives only the remaining payload.
    # Defaulting to None treats a missing key as an unknown event instead
    # of raising KeyError (which would surface as an unhandled 500).
    event = data.pop('action', None)
    if event not in TRAIN_EVENTS:
        return jsonify({'status': 'Unknown event'}), 200
    project = data['project']
    project_id = str(project['id'])
    label_config = project['label_config']
    model = MODEL_CLASS(project_id, cache)
    model.use_label_config(label_config)
    model.fit(event, data)
    return jsonify({}), 201
@_server.route('/health', methods=['GET'])
@_server.route('/', methods=['GET'])
@exception_handler
def health():
    """Health check, served on both ``/health`` and ``/``.

    Returns:
        JSON with ``status`` fixed to ``'UP'``, the name of the model class
        being served and the name of the cache class in use.
    """
    # Until init_app() runs, `cache` is the BaseCache class itself rather
    # than an instance. `cache.__class__` would then name its metaclass
    # (e.g. 'type'), so report the class's own name in that case.
    if isinstance(cache, type):
        cache_type = cache.__name__
    else:
        cache_type = cache.__class__.__name__
    return jsonify({
        'status': 'UP',
        'model_class': MODEL_CLASS.__name__,
        'cache_type': cache_type,
    })
@_server.route('/metrics', methods=['GET'])
@exception_handler
def metrics():
    """Metrics endpoint placeholder; always responds with an empty JSON object."""
    no_metrics = {}
    return jsonify(no_metrics)
@_server.errorhandler(FileNotFoundError)
def file_not_found_error_handler(error):
    """Turn a FileNotFoundError reaching Flask into a plain-text 404 response.

    Args:
        error: The FileNotFoundError raised while handling the request.

    Returns:
        A ``(str(error), 404)`` response tuple.
    """
    # Lazy %-style argument: the message is only formatted when WARNING
    # records are actually emitted (the logged text is unchanged).
    logger.warning('Got error: %s', error)
    return str(error), 404
@_server.errorhandler(AssertionError)
def assertion_error(error):
    """Log an AssertionError with its traceback and return it as HTTP 500.

    Args:
        error: The AssertionError raised while handling the request.

    Returns:
        A ``(str(error), 500)`` response tuple.
    """
    message = str(error)
    # Logger.exception is Logger.error with exc_info=True.
    logger.exception(message)
    return message, 500
@_server.errorhandler(IndexError)
def index_error(error):
    """Log an IndexError with its traceback and return it as HTTP 500.

    Args:
        error: The IndexError raised while handling the request.

    Returns:
        A ``(str(error), 500)`` response tuple.
    """
    text = str(error)
    logger.exception(text)  # same as logger.error(text, exc_info=True)
    return text, 500
@_server.before_request
def log_request_info():
    """Log incoming request headers and raw body at DEBUG level.

    Returns None so Flask continues dispatching the request normally.
    """
    # request.get_data() may read and buffer the entire request body; skip
    # that work on every request unless DEBUG logging is actually enabled.
    if not logger.isEnabledFor(logging.DEBUG):
        return
    logger.debug('Request headers: %s', request.headers)
    logger.debug('Request body: %s', request.get_data())
@_server.after_request
def log_response_info(response):
    """Log outgoing response status, headers and body at DEBUG level.

    Args:
        response: The response object Flask is about to send.

    Returns:
        The same response object, unchanged.
    """
    # response.get_data() may materialize the whole response body (e.g.
    # large prediction payloads); only do it when DEBUG is enabled.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug('Response status: %s', response.status)
        logger.debug('Response headers: %s', response.headers)
        logger.debug('Response body: %s', response.get_data())
    return response