Skip to content

Commit

Permalink
GUYS I KNOW HOW TO MULTITHREAD :SNAKE:
Browse files Browse the repository at this point in the history
  • Loading branch information
pjreddie committed Jul 11, 2017
1 parent 59ed171 commit 616e630
Show file tree
Hide file tree
Showing 18 changed files with 319 additions and 151 deletions.
9 changes: 7 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
GPU=0
CUDNN=0
OPENCV=0
OPENMP=0
DEBUG=0

ARCH= -gencode arch=compute_20,code=[sm_20,sm_21] \
Expand All @@ -19,13 +20,17 @@ EXEC=darknet
OBJDIR=./obj/

CC=gcc
NVCC=nvcc --compiler-options '-fPIC'
NVCC=nvcc
AR=ar
ARFLAGS=rcs
OPTS=-Ofast
LDFLAGS= -lm -pthread
COMMON= -Iinclude/ -Isrc/
CFLAGS=-Wall -Wfatal-errors -fPIC
CFLAGS=-Wall -Wno-unknown-pragmas -Wfatal-errors -fPIC

ifeq ($(OPENMP), 1)
COMMON+= -fopenmp
endif

ifeq ($(DEBUG), 1)
OPTS=-O0 -g
Expand Down
2 changes: 1 addition & 1 deletion cfg/imagenet1k.data
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
classes=1000
train = /data/imagenet/imagenet1k.train.list
valid = /data/imagenet/imagenet1k.train.list
valid = /data/imagenet/imagenet1k.valid.list
backup = /home/pjreddie/backup/
labels = data/imagenet.labels.list
names = data/imagenet.shortnames.list
Expand Down
7 changes: 3 additions & 4 deletions examples/classifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
load_args args = {0};
args.w = net.w;
args.h = net.h;
args.threads = 32;
args.threads = 64;
args.hierarchy = net.hierarchy;

args.min = net.min_crop;
Expand Down Expand Up @@ -670,7 +670,6 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
int *indexes = calloc(top, sizeof(int));
char buff[256];
char *input = buff;
int size = net.w;
while(1){
if(filename){
strncpy(input, filename, 256);
Expand All @@ -682,8 +681,8 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
strtok(input, "\n");
}
image im = load_image_color(input, 0, 0);
image r = resize_min(im, size);
resize_network(&net, r.w, r.h);
image r = letterbox_image(im, net.w, net.h);
//resize_network(&net, r.w, r.h);
//printf("%d %d\n", r.w, r.h);

float *X = r.data;
Expand Down
2 changes: 1 addition & 1 deletion examples/coco.c
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ void test_coco(char *cfgfile, char *weightfile, char *filename, float thresh)
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
get_detection_boxes(l, 1, 1, thresh, probs, boxes, 0);
if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, coco_classes, alphabet, 80);
draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, 0, coco_classes, alphabet, 80);
save_image(im, "prediction");
show_image(im, "predictions");
free_image(im);
Expand Down
38 changes: 19 additions & 19 deletions examples/detector.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
//int N = plist->size;
char **paths = (char **)list_to_array(plist);

load_args args = {0};
args.w = net.w;
args.h = net.h;
load_args args = get_base_args(net);
args.coords = l.coords;
args.paths = paths;
args.n = imgs;
args.m = plist->size;
Expand All @@ -52,13 +51,9 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
args.num_boxes = l.max_boxes;
args.d = &buffer;
args.type = DETECTION_DATA;
//args.type = INSTANCE_DATA;
args.threads = 8;

args.angle = net.angle;
args.exposure = net.exposure;
args.saturation = net.saturation;
args.hue = net.hue;

pthread_t load_thread = load_data(args);
clock_t time;
int count = 0;
Expand Down Expand Up @@ -102,7 +97,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i
image im = float_to_image(net.w, net.h, 3, train.X.vals[zz]);
int k;
for(k = 0; k < l.max_boxes; ++k){
box b = float_to_box(train.y.vals[zz] + k*5);
box b = float_to_box(train.y.vals[zz] + k*5, 1);
printf("%f %f %f %f\n", b.x, b.y, b.w, b.h);
draw_bbox(im, b, 1, 1,0,0);
}
Expand Down Expand Up @@ -130,7 +125,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i

i = get_current_batch(net);
printf("%ld: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
if(i%1000==0){
if(i%100==0){
#ifdef GPU
if(ngpus != 1) sync_nets(nets, ngpus, 0);
#endif
Expand Down Expand Up @@ -342,7 +337,7 @@ void validate_detector_flip(char *datacfg, char *cfgfile, char *weightfile, char
network_predict(net, input.data);
int w = val[t].w;
int h = val[t].h;
get_region_boxes(l, w, h, net.w, net.h, thresh, probs, boxes, 0, map, .5, 0);
get_region_boxes(l, w, h, net.w, net.h, thresh, probs, boxes, 0, 0, map, .5, 0);
if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);
if (coco){
print_cocos(fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
Expand Down Expand Up @@ -473,7 +468,7 @@ void validate_detector(char *datacfg, char *cfgfile, char *weightfile, char *out
network_predict(net, X);
int w = val[t].w;
int h = val[t].h;
get_region_boxes(l, w, h, net.w, net.h, thresh, probs, boxes, 0, map, .5, 0);
get_region_boxes(l, w, h, net.w, net.h, thresh, probs, boxes, 0, 0, map, .5, 0);
if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);
if (coco){
print_cocos(fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
Expand Down Expand Up @@ -537,7 +532,7 @@ void validate_detector_recall(char *cfgfile, char *weightfile)
image sized = resize_image(orig, net.w, net.h);
char *id = basecfg(path);
network_predict(net, sized.data);
get_region_boxes(l, sized.w, sized.h, net.w, net.h, thresh, probs, boxes, 1, 0, .5, 1);
get_region_boxes(l, sized.w, sized.h, net.w, net.h, thresh, probs, boxes, 0, 1, 0, .5, 1);
if (nms) do_nms(boxes, probs, l.w*l.h*l.n, 1, nms);

char labelpath[4096];
Expand Down Expand Up @@ -589,11 +584,11 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
}
set_batch_network(&net, 1);
srand(2222222);
clock_t time;
double time;
char buff[256];
char *input = buff;
int j;
float nms=.4;
float nms=.3;
while(1){
if(filename){
strncpy(input, filename, 256);
Expand All @@ -615,15 +610,20 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(l.classes + 1, sizeof(float *));
float **masks = 0;
if (l.coords > 4){
masks = calloc(l.w*l.h*l.n, sizeof(float*));
for(j = 0; j < l.w*l.h*l.n; ++j) masks[j] = calloc(l.coords-4, sizeof(float *));
}

float *X = sized.data;
time=clock();
time=what_time_is_it_now();
network_predict(net, X);
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
get_region_boxes(l, im.w, im.h, net.w, net.h, thresh, probs, boxes, 0, 0, hier_thresh, 1);
printf("%s: Predicted in %f seconds.\n", input, what_time_is_it_now()-time);
get_region_boxes(l, im.w, im.h, net.w, net.h, thresh, probs, boxes, masks, 0, 0, hier_thresh, 1);
if (nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms);
//else if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, l.classes, nms);
draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, names, alphabet, l.classes);
draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, masks, names, alphabet, l.classes);
if(outfile){
save_image(im, outfile);
}
Expand Down
48 changes: 48 additions & 0 deletions examples/rnn.c
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,54 @@ void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float t
printf("\n");
}

void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
{
char **tokens = 0;
if(token_file){
size_t n;
tokens = read_tokens(token_file, &n);
}

srand(rseed);
char *base = basecfg(cfgfile);
fprintf(stderr, "%s\n", base);

network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
int inputs = net.inputs;

int i, j;
for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
int c = 0;
float *input = calloc(inputs, sizeof(float));
float *out = 0;

while(1){
reset_rnn_state(net, 0);
while((c = getc(stdin)) != EOF && c != 0){
input[c] = 1;
out = network_predict(net, input);
input[c] = 0;
}
for(i = 0; i < num; ++i){
for(j = 0; j < inputs; ++j){
if (out[j] < .0001) out[j] = 0;
}
int next = sample_array(out, inputs);
if(c == '.' && next == '\n') break;
c = next;
print_symbol(c, tokens);

input[c] = 1;
out = network_predict(net, input);
input[c] = 0;
}
printf("\n");
}
}

void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
{
char **tokens = 0;
Expand Down
3 changes: 1 addition & 2 deletions examples/yolo.c
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,7 @@ void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
get_detection_boxes(l, 1, 1, thresh, probs, boxes, 0);
if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
//draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, alphabet, 20);
draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, alphabet, 20);
draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, 0, voc_names, alphabet, 20);
save_image(im, "predictions");
show_image(im, "predictions");

Expand Down
11 changes: 8 additions & 3 deletions include/darknet.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ struct layer{
float coord_scale;
float object_scale;
float noobject_scale;
float mask_scale;
float class_scale;
int bias_match;
int random;
Expand Down Expand Up @@ -508,7 +509,7 @@ typedef struct{
} data;

typedef enum {
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA, LETTERBOX_DATA, REGRESSION_DATA, SEGMENTATION_DATA
CLASSIFICATION_DATA, DETECTION_DATA, CAPTCHA_DATA, REGION_DATA, IMAGE_DATA, COMPARE_DATA, WRITING_DATA, SWAG_DATA, TAG_DATA, OLD_CLASSIFICATION_DATA, STUDY_DATA, DET_DATA, SUPER_DATA, LETTERBOX_DATA, REGRESSION_DATA, SEGMENTATION_DATA, INSTANCE_DATA
} data_type;

typedef struct load_args{
Expand All @@ -530,6 +531,7 @@ typedef struct load_args{
int background;
int scale;
int center;
int coords;
float jitter;
float angle;
float aspect;
Expand Down Expand Up @@ -642,7 +644,7 @@ void save_weights_upto(network net, char *filename, int cutoff);
void load_weights_upto(network *net, char *filename, int start, int cutoff);

void zero_objectness(layer l);
void get_region_boxes(layer l, int w, int h, int netw, int neth, float thresh, float **probs, box *boxes, int only_objectness, int *map, float tree_thresh, int relative);
void get_region_boxes(layer l, int w, int h, int netw, int neth, float thresh, float **probs, box *boxes, float **masks, int only_objectness, int *map, float tree_thresh, int relative);
void free_network(network net);
void set_batch_network(network *net, int b);
image load_image(char *filename, int w, int h, int c);
Expand Down Expand Up @@ -677,13 +679,15 @@ void random_distort_image(image im, float hue, float saturation, float exposure)
void fill_image(image m, float s);
image grayscale_image(image im);
void rotate_image_cw(image im, int times);
double what_time_is_it_now();
image rotate_image(image m, float rad);
void visualize_network(network net);
float box_iou(box a, box b);
void do_nms(box *boxes, float **probs, int total, int classes, float thresh);
data load_all_cifar10();
box_label *read_boxes(char *filename, int *n);
void draw_detections(image im, int num, float thresh, box *boxes, float **probs, char **names, image **labels, int classes);
box float_to_box(float *f, int stride);
void draw_detections(image im, int num, float thresh, box *boxes, float **probs, float **masks, char **names, image **alphabet, int classes);

matrix network_predict_data(network net, data test);
image **load_alphabet();
Expand All @@ -709,6 +713,7 @@ image get_image_from_stream(CvCapture *cap);
void free_image(image m);
float train_network(network net, data d);
pthread_t load_data_in_thread(load_args args);
void load_data_blocking(load_args args);
list *get_paths(char *filename);
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves, int stride);
void change_leaves(tree *t, char *leaf_list);
Expand Down
13 changes: 6 additions & 7 deletions src/blas_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,6 @@ extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int
int minw = (w1 < w2) ? w1 : w2;
int minh = (h1 < h2) ? h1 : h2;
int minc = (c1 < c2) ? c1 : c2;
assert(w1 == w2);
assert(h1 == h2);
assert(c1 == c2);

int stride = w1/w2;
int sample = w2/w1;
Expand Down Expand Up @@ -892,19 +889,21 @@ __global__ void softmax_tree_kernel(float *input, int spatial, int batch, int st

extern "C" void softmax_tree(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier)
{
//int *tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
//int *tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
int *tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
int *tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
/*
static int *tree_groups_size = 0;
static int *tree_groups_offset = 0;
if(!tree_groups_size){
tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
}
*/
int num = spatial*batch*hier.groups;
softmax_tree_kernel<<<cuda_gridsize(num), BLOCK>>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
check_error(cudaPeekAtLastError());
//cuda_free((float *)tree_groups_size);
//cuda_free((float *)tree_groups_offset);
cuda_free((float *)tree_groups_size);
cuda_free((float *)tree_groups_offset);
}

__global__ void softmax_kernel(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
Expand Down
1 change: 0 additions & 1 deletion src/box.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ typedef struct{
float dx, dy, dw, dh;
} dbox;

box float_to_box(float *f, int stride);
float box_rmse(box a, box b);
dbox diou(box a, box b);
box decode_box(box b, box anchor);
Expand Down
4 changes: 4 additions & 0 deletions src/convolutional_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ void cudnn_convolutional_setup(layer *l)
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);
cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
#if CUDNN_MAJOR >= 6
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);
#else
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
#endif
cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
l->srcTensorDesc,
l->weightDesc,
Expand Down
Loading

0 comments on commit 616e630

Please sign in to comment.