Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linear Chain CRF layer and a text chunking example #4621

Closed
wants to merge 15 commits into from

Conversation

phipleg
Copy link
Contributor

@phipleg phipleg commented Dec 6, 2016

This pull request relates to issue #4090. It adds a new layer ChainCRF with dedicated loss function and Viterbi style decoding for infering the best tag sequence. To demonstrate the use, an example for text chunking is given as well.

'''
return chain_crf_loss(y_true, y_pred, self.U, self.b)

def sparse_loss(self, y_true, y_pred):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting trick.

@linxihui
Copy link
Contributor

linxihui commented Dec 8, 2016

Hmm, interesting. I also have an implementation, and I am thinking of making a pull request and I saw your PR. My implementation supports masking, Viterbi decoding, computes marginal probability, and allows the CRF to be used as an intermediate layer (instead of last layer only) if using marginal mode (instead of join distribution mode). My PR is here #4646
I am thinking if we could combine both of our code?

@phipleg
Copy link
Contributor Author

phipleg commented Dec 9, 2016

Sorry for the late reply. I'm lying in bed with a bad cold.

Good idea to join forces @linxihui! It is very nice that you don't need further extension of the keras backend. I guess I could refactor my code, to make use of the K.rnn method instead of using theano.scan and tensorflow.scan. Then no extension of the backend would be necessary. However this is not so clear to me, since I do a lot of workarounds to avoid converting the sparse targets to one-hot encoded vectors.

Moreover, I could make masking work if that is really necessary for the first release. In my experience masking makes everything very slow without much benefit. That's why I didn't put much effort in it.

Then, if I understand your implementation correctly, using your CRF as a last layer with learn_mode=join and test_mode=viterbi is almost the same as a dense layer followed by my ChainCRF. However, the latter supports a boundary condition on the left, by providing a bias weight to learn the transition between a virtual start label and the first label of the target sequence. I've seen implementations of linear CRF's, where the input and target sequences are embedded in sequences of length plus two in order to deal with the left and right boundary condition. I didn't handle the end boundary condition since in my application (sentence tagging), the sequences have variable length and are always padded on the right, so that the padding elements act as a virtual end label. You can see the benefit of handling these boundary conditions in my integration tests. There, the test data has the property that the input x_t and the target y_t are independent, except when t=0, and y_t = y_{t-1} + 1 (modulo nb_classes). Nonetheless, the network is able to learn the correct tag sequence (with accuracy >= 0.94), which is only possible when the left boundary condition is handled correctly. Admittedly, this data is quite artificial though and doesn't resemble real data like the one seen in text chunking.

What are the applications of using your CRF layer as an intermediate layer with returning marginal probabilties? Should this functionality be in the same class, or do you see a way to move it to another CRF class?

@fchollet
Copy link
Member

fchollet commented Dec 9, 2016

I guess I could refactor my code, to make use of the K.rnn method instead of using theano.scan and tensorflow.scan. Then no extension of the backend would be necessary.

That is a better approach. Note that K.rnn is completely general and should be sufficient to implement any iterative loop processing.

@linxihui
Copy link
Contributor

linxihui commented Dec 20, 2016

  1. Yes, I did see some implementation has the "boundary" energies, including Baidu's paddle paddle, though not specified in most of the theoretical works. Indeed, I did consider adding it as an option. BTW, I did not understand how you handle mask in CRF when it is added on top of a fixed-length RNN with mask? What if there is a sample with all mask=1 (not actual padding at the end)? and what if there is multiple (said 3) paddings at the end, does your implementation add "multiple virtual" boundary energies? What about left padding?

  2. As to your question about using CRF as an intermediate layer, honestly, I could not actually find an example. I did give a try to stack CRFs on a NER dataset, but it does not turn out to be beneficial (it does not make a lot of sense to have a 'distribution' as hidden layer here). A possible 'theoretical scenario' would be to learn some hidden(or another output) class distribution if the target can be factorized as Pr(y|x) = \sum_z Pr(y|x, z)Pr(z|x) or Pr(y, z|x)=Pr(y|z, x)Pr(z|x), where Pr(z|x) is be modeled as CRF. The forward-backward recursive algorithm acts like a sum-product approach for factor graph. But I don't know if it would work or not and the user can decide it.

@linxihui
Copy link
Contributor

@phipleg Do you agree that we should merge our code? Use my implementation (if you think it could be better option), but with your unit test and example? I have some code for handling boundary energies and I will make a commit after work.

@phipleg
Copy link
Contributor Author

phipleg commented Dec 30, 2016

Hey @linxihui !

I've completed my rewrite the past few days, so that the CRF makes use of K.rnn with only minor additions to the backend (as suggested by @fchollet ) and everything has unit tests as before. Regarding your questions:

  1. The layer doesn't support any kind of masking. Do you need this? As I said, I only handle padding on the left with a single virtual element. Due to variable length sequences, padding on the right is done by non-virtual elements anyway, so there is no need for adding more padding logic.

  2. Would it be ok for you to ommit the use case of placing the CRF as an intermediate layer? I think this would overload the term CRF a bit too much since this is not anymore what the user would expect.

If it is ok for you, I would be happy to leave it now as it is. We can add more functionality with another pull request, ok?

Copy link
Contributor

@linxihui linxihui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phipleg

  1. For masking, please see my comments in your code. Yes, you can just have right padding, but then you have to highlight on your documentation that left padding is not accepted. People will take it for grant as RNN supports either.
  2. I apologize for mention too much about using CRF as intermediate layer. Indeed, this is just a side effect I discovered after implementing the marginal mode. The marginal mode was originally for computing the marginal probabilities, since Viterbi only gives labels. I don't encourage people use it, but just mention it in case anyone interest in it.

def _forward_step(x_t, states):
alpha_tm1, U_shared = states
B = K.expand_dims(alpha_tm1, 2) + K.expand_dims(x_t, 1) + K.expand_dims(U_shared, 0)
alpha_t = K.logsumexp(B, axis=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With not mask, how do you handle a batch with different sequence lengths? Said a batch of 2, with first input have lenght 2 and the 2nd has a lenght 4. Then you would do a right padding on the 1st sequence to make it have length 4. In you implementation without mask, for the 1st input, the true energy should be
b_start + x1' y1 + x2' y2 + y1' U y2 + b_end
but in your implementation, there isn't a b_end, but with an additional
x3' y3' + x4' y4 + y2' U y3 + y3' U y4,
where y3 = y4 = 0.
The consequence is, the above two formulations are not equivalent, at least when you take derivative with respect to U_00 (top-left element in matrix U), the derivative isn't the same. Right? (also, U_00 and U_11 are not exchangable, but why we treat label 0 and label 1 differently?)

Also when you compute the normalization constance (free energy in your code), you have to integrate over y3, y4 (which are paddings). I guess that's what you mean by "padding elements act as a virtual end label". However, if when you think about taking derivative with respect to U or b_end, your approach is not equivalent to a real CRF.

One very obvious observation is, y3, y4, the padding, affects the derivative with respect to U, and therefore, the paddings plays a role on the final outcome. The more paddings you have, the more impact the paddings affects the outcome. This is unexpected from my point of view.

Lastly, another simple observation, a model with and without the end energy (b_end), the numbers of trainable parameters are not the same. So the two models are not the same.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for these remarks. You are completely right, that the models are not the same.

I don't handle batches of different length explicitely. They are embedded in sequences of fixed length by design, so the user has to do the conversion himself.

@nreimers
Copy link

Hi,
thanks for the code.

I had a problem when saving the model with keras model.save(). It threw a exception, that an h5py object could not be initialized.

The bug is located in the crf.py on line 267 & 268:
self.b_start = K.zeros((n_classes, ), name='{}_b_start'.format(self.name))
self.b_end = K.zeros((n_classes, ), name='{}_b_start'.format(self.name))

Change it like this
self.b_start = K.zeros((n_classes, ), name='{}_b_start'.format(self.name))
self.b_end = K.zeros((n_classes, ), name='{}_b_end'.format(self.name))

@nreimers
Copy link

I think deserialization using keras.models.load_model() isn't working yet in the this implementation.

When I load the model like

crf = ChainCRF()
model = keras.models.load_model('model.h5', custom_objects={'ChainCRF': crf, 'sparse_loss': crf.sparse_loss})

I get the error during the model.compile() call, that 'self.U' does not exist in the sparse_loss-Function. The reason is, that when the model is constructed from the saved model, a new CRF object is created. However, the passed 'sparse_loss' function is from a different CRF-object, not the one that the reconstructed model is using and therefore, the parameters U, start_b and end_b are not set.

Any ideas how to fix this / how to pass the custom_objects properly so that I can load a stored model?

@phipleg
Copy link
Contributor Author

phipleg commented Jan 23, 2017

Hi @nreimers,

Thank you very much for pointing out the bug in the definition of b_end! It is fixed.

Regarding your question about model loading, I wrote a method create_custom_objects that creates the correct custom objects. You can see it in action in the new test test_persistence.

@nreimers
Copy link

Thanks for the quick reply / quick fix.

Just as general note:
I tested this CRF implementation as well as the CRF implementation in #4646 on the CoNLL 2003 English NER dataset with a bidirectional LSTM and found out, that his implementation performs superior (sometimes ~1% more F1-measure). The implementation presented in #4646 has issues with the BIO encoding and produces a larger number of ill-formatted tags, i.e. tags start with an I- tag without a B-tag before.

@phipleg
Copy link
Contributor Author

phipleg commented Jan 24, 2017

Hi @nreimers ,

just to clarify: You say that the implementation #4648 is superior, or did you mean this (and not his)?

@nreimers
Copy link

Hi @phipleg,
sorry when I was inprecise.

With this implementation, I achieve an F1 score of about 0.89 on CoNLL 2003 NER.
With the implementation of #4646, the F1 score is at about 0.88, roughly 1% below of this implementation.

This implementation produces between 0 - 3 wrong tags on the dev/test data, i.e. an I-tag starts without a previous B-tag.

With the implementation of #4646, the number of wrong BIO tags are between 20 - 60 on dev/test set.

In both cases I use the word embeddings by Levy et al., a 100-dim. Bi-LSTM, a dense hidden layer with linear activation function and the CRF.
The implementation of #4646 has a dense hidden layer already included in the CRF layer, so #4646 works on the output of the Bi-LSTM.

@tboquet
Copy link
Contributor

tboquet commented Jan 31, 2017

@phipleg you could consider opening another PR (in parallel) in the contrib repo 🎉 ! There are more official reviewers so you could have more feedback.

@phipleg
Copy link
Contributor Author

phipleg commented Feb 1, 2017

Hi @nreimers,

Thanks a lot for this detailed test. You probably used the english dataset, right?

@phipleg
Copy link
Contributor Author

phipleg commented Feb 1, 2017

Hi @tboquet,

Thanks a lot for your suggestion. I will do that tomorrow evening.

@nreimers
Copy link

nreimers commented Feb 1, 2017

Hi @phipleg,
correct, I used the English CoNLL 2003 dataset (which is sadly not freely available). State of the art is at about 0.90 F1-measure on that dataset.

@napsternxg
Copy link

Hi @phipleg thanks for adding the ChainCRF layer. However, I think it is breaking functionality of keras when using sample_weight_mode="temporal". Here is an example which succeeds:

import numpy as np

from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, ChainCRF, TimeDistributed

nb_classes=3
X = np.random.randn(10,5,11)
y = np.random.randint(0,nb_classes,size=(10,5))
y = np.expand_dims(y, -1)
y_mask = np.random.randint(0,2,size=(10,))
X.shape, y.shape, y_mask.shape

model = Sequential()
model.add(TimeDistributed(Dense(nb_classes),
                          input_shape=(5,11), name="temporal_dense"))
crf = ChainCRF()
model.add(crf)
model.summary()

model.compile(loss=crf.sparse_loss, optimizer='sgd')
model.fit(X, y, sample_weight=y_mask, nb_epoch=1)

However, when I edit the following modification of the above code which uses temporal sample weights fails:

y_mask = np.random.randint(0,2,size=(10,5))
model = Sequential()
model.add(TimeDistributed(Dense(nb_classes, #activation='softmax'
                               ),
                          input_shape=(5,11), name="temporal_dense"))
crf = ChainCRF()
model.add(crf)
model.summary()
model.compile(loss=crf.sparse_loss, optimizer='sgd', sample_weight_mode='temporal')
model.fit(X, y, sample_weight=y_mask, nb_epoch=1)

I get the following error:

Incompatible shapes: [10] vs. [10,5]
	 [[Node: mul_2 = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](Mean, _recv_chaincrf_1_sample_weights_0/_19)]]
	 [[Node: while_1/Identity_1/_75 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_1073_while_1/Identity_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](^_cloopwhile_1/TensorArrayReadV2/_1)]]

Will it be appropriate to consider the sample weight per time step in the CRF loss or should the sample weighing be done at the output level?

@napsternxg
Copy link

@phipleg also it might be helpful to utilize the tensorflow API for CRF loss, as described in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/crf/python/ops/crf.py
They have a good API usage example at: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/crf

The tensorflow API also uses the masks per time_step item (the masks are created based on sequence lengths, currently).

I am not sure how to go about it, but in the current approach the is a requirement for holding the CRF object for the ChainCRF layer, in order to define the loss function using crf.sparse_loss. If there is a way to do it without holding the CRF object that would make integrating it with the current keras API much easier.

'''
y_true = K.cast(y_true, 'int32')
y_true = K.squeeze(y_true, 2)
return sparse_chain_crf_loss(y_true, y_pred, self.U, self.b_start, self.b_end)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should fetch the mask in the sparse_loss function as well, the same way you do in loss function.
Also, all the loss should be the mean loss per batch, right now you are simply doing the sum of the losses. This makes the losses for different batch sizes to be of different scales.

@phipleg
Copy link
Contributor Author

phipleg commented Feb 5, 2017

Hi @napsternxg!

Thank you very much for your review! I added the missing mask in crf.sparse_loss.

However, I couldn't fix the problem with the temporal sample weights. The problem is, that usually, the loss functions return a value for each batch element and each time step, i.e. return a tensor of shape (batch_size, maxlen, ...). The total loss with calculating means and sample weighting is done later on in keras.engine.training. But the loss function of a chain CRF is by design reducing along the time axis, so we cannot get loss value for each time step in a canonical way. It used to return a tensor of shape (batch_size, ) and I changed this to be (batch_size, 1) (by reshaping). Now you can at least apply a scalar weight on each sample.

Also, thank you for pointing out the CRF implementation in TensorFlow. Unfortunately, I don't see a way to integrate it in Keras. The main problem is that the loss functions in Keras are usually pure functions depending only on the predicted and the true values: They cannot depend on trainable weights. Currently, I have no better answer.

@napsternxg
Copy link

@phipleg thanks for the update. I think in that case masking is the best option for getting the loss without padding.

@nreimers
Copy link

Hi,
I try to apply this layer to non-fixed-lengthed data, i.e. I do not pad the sentences to a common length and every batch does have a different sentence length (most easiest case: performing online training and just training sentence by sentence, every sentence with a different length).

The first issue is a minor one: In line 289 of crf.py is the following code:
assert n_steps >= 2

This assertion fails when the number of steps is not defined a priori. Removing this assertion solves the issue.

The second issue is Tensorflow specific (i.e. the CRF layer works perfect with Theano). I get the following message when switching to Tensorflow as backend:

 File "/home/username/hyperopt/neuralnets/BiLSTM.py", line 105, in buildModel
    model.compile(loss=lossFct, optimizer=opt)
  File "/home/username/.local/lib/python2.7/site-packages/keras/models.py", line 594, in compile
    **kwargs)
  File "/home/username/.local/lib/python2.7/site-packages/keras/engine/training.py", line 667, in compile
    sample_weight, mask)
  File "/home/username/.local/lib/python2.7/site-packages/keras/engine/training.py", line 318, in weighted
    score_array = fn(y_true, y_pred)
  File "/home/username/hyperopt/keraslayers/ChainCRF.py", line 401, in sparse_loss
    return sparse_chain_crf_loss(y_true, y_pred, self.U, self.b_start, self.b_end, mask)
  File "/home/username/hyperopt/keraslayers/ChainCRF.py", line 116, in sparse_chain_crf_loss
    energy -= free_energy0(x, U, mask)
  File "/home/username/hyperopt/keraslayers/ChainCRF.py", line 182, in free_energy0
    mask)
  File "/home/username/hyperopt/keraslayers/ChainCRF.py", line 204, in _forward
    last, values, _ = K.rnn(_forward_step, inputs, initial_states)
  File "/home/username/.local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2141, in rnn
    swap_memory=True)
  File "/home/username/.local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2636, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/username/.local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2469, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/username/.local/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2419, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/username/.local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2130, in _step
    tuple(constants))
  File "/home/username/hyperopt/keraslayers/ChainCRF.py", line 191, in _forward_step
    new_states = reduce_step(K.expand_dims(alpha_tm1, 2) + energy_matrix_t)
  File "/home/username/.local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 1606, in expand_dims
    return tf.expand_dims(x, dim)
  File "/home/username/.local/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 138, in expand_dims
    return gen_array_ops._expand_dims(input, axis, name)
  File "/home/username/.local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 985, in _expand_dims
    result = _op_def_lib.apply_op("ExpandDims", input=input, dim=dim, name=name)
  File "/home/username/.local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 759, in apply_op
    op_def=op_def)
  File "/home/username/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2240, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/username/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1128, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Tried to expand dim index 2 for tensor with 1 dimensions.
	 [[Node: while_4/ExpandDims = ExpandDims[T=DT_FLOAT, Tdim=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](while_4/Identity_2, while_4/ExpandDims/dim)]]

I use the current Tensorflow 0.12.1 and Keras in version 1.2.1.

My model looks something like this:

model = Sequential()
model.add(Embedding(input_dim=embeddings.shape[0], output_dim=embeddings.shape[1],  weights=[embeddings], trainable=False))
model.add(LSTM(100, return_sequences=True))
crf = ChainCRF()
model.add(crf) 

As the length of my sentences & mini batches varies, I omit the input_length=maxlen parameter for the Embedding layer. With Theano it works perfect, however, with Tensorflow I get the above error.

Philipp Gross added 2 commits February 13, 2017 22:57
@phipleg
Copy link
Contributor Author

phipleg commented Feb 13, 2017

Hi @nreimers,

I'am very glad that you test the layer so thoroughly! If you need any help, please let me know.

The problems when working with mini batches of size 1 should be fixed (by upgrading K.logsumexp in tensorflow_backend).

Happy testing

@ipoletaev
Copy link

When approximately we will see the appearance of the final version of CRF layer in the master branch?

@phipleg
Copy link
Contributor Author

phipleg commented Feb 16, 2017

As suggested by @tboquet I opened the parallel PR keras-team/keras-contrib#25 in order to speed up the reviewing process! 🎉

@kaya27
Copy link

kaya27 commented Mar 3, 2017

Hello @nreimers , @phipleg thanks for the code!
I'm training my for mini batches of varying sentence length (batch size=1) using train_on_batch, the code is very well running however I face a problem when (sentence_length=1)

The error value is:
ValueError: ('Sequence is shorter then the required number of steps : (n_steps, seq, seq.shape):', 1, CudaNdarray([]), (0, 1, 1, 436))
Apply node that caused the error: forall_inplace,gpu,scan_fn}(Elemwise{maximum,no_inplace}.0, GpuDimShuffle{0,1,x,2}.0, GpuIncSubtensor{InplaceSet;:int64:}.0, GpuDimShuffle{x,0,1}.0)
Toposort index: 1045
Inputs types: [TensorType(int64, scalar), CudaNdarrayType(float32, (False, False, True, False)), CudaNdarrayType(float32, 3D), CudaNdarrayType(float32, (True, False, False))]
Inputs shapes: [(), (0, 1, 1, 436), (3, 1, 436), (1, 436, 436)]
Inputs strides: [(), (0, 0, 0, 1), (436, 0, 1), (0, 436, 1)]
Inputs values: [array(1), CudaNdarray([]), 'not shown', 'not shown']
Outputs clients: [[GpuSubtensor{int64:int64:int64}(forall_inplace,gpu,scan_fn}.0, ScalarFromTensor.0, ScalarFromTensor.0, Constant{-1}), GpuSubtensor{int64}(forall_inplace,gpu,scan_fn}.0, ScalarFromTensor.0)]]

regarding I've replaced n_steps >= 2
by:
assert n_steps == None or n_steps >= 2

Thanks in advance

@djstrong
Copy link
Contributor

djstrong commented Mar 6, 2017

I have tried it with 1000 outputs and got memory error:
MemoryError: Error allocating 53561380864 bytes of device memory (CNMEM_STATUS_OUT_OF_MEMORY).
However output with flag exception_verbosity=high shows:
TotalSize: 463608342.0 Byte(s) 0.432 GB
TotalSize inputs: 145905774.0 Byte(s) 0.136 GB

Why it tries to allocate much more memory?

@fchollet
Copy link
Member

Closing outdated PR. If you still care about the content of the PR, please submit a new PR to master, updated for the Keras 2.0 API.

@fchollet fchollet closed this Mar 15, 2017
@nreimers
Copy link

I would love to see this layer included in Keras 2.0

I use it for several NLP task and in all tasks it shows a strong performance increase in comparison to a softmax classifier. For me this layer is a must have if you do sequence tagging, e.g. for sequence tagging for NLP.

@phipleg
Copy link
Contributor Author

phipleg commented Mar 17, 2017

Sorry for the late reply, life got in the way. Thanks for testing! I will work on another PR soon, in order to give the CRF a change to get into Keras 2.0.

@kaya27: I introduced an error while handling the step size problem, consider this fixed.

@djstrong: I haven't investigated your problem yet. At some point the current implementation converts sparse outputs to dense ones as an intermediate step. This might be related to your problem.

@harryhaos
Copy link

Is there paper about your spare loss definition? I can't really understand it……

@mtmvu
Copy link

mtmvu commented Mar 25, 2017

@phipleg
If you need help for updating your crf layer to be keras-2 compliant, please let me know

@phipleg
Copy link
Contributor Author

phipleg commented Mar 26, 2017

Dear @harryhaos,

The spare loss is defined in https://arxiv.org/pdf/1603.01360.pdf .

@phipleg
Copy link
Contributor Author

phipleg commented Mar 26, 2017

Dear @mtmvu,

The update is almost done (see https://github.com/phipleg/keras/tree/crf), but can't fix the failing tests and the persistence workaround doesn't work anymore. If your time permits, I would be happy you could join me!

Best and happy weekend.

@zhhongzhi
Copy link

@phipleg
Thank you for your contribution. I wonder whether the crf layer is available in Keras 2.
from keras.layers import ChainCRF
raise an ImportError
Traceback (most recent call last):
File "", line 1, in
ImportError: cannot import name ChainCRF

Or, whether this version is ok?
https://github.com/phipleg/keras

@kaya27
Copy link

kaya27 commented Apr 12, 2017

hi @zhhongzhi
it should work if the library is well installed. I guess you had a previous keras version
uninstall keras and try again:
python setup.py install --force

@xtknight
Copy link

xtknight commented Apr 22, 2017

To those of you for whom it was not so obvious how to install this:
(First uninstall normal keras)
$ sudo pip3 uninstall keras
$ git clone https://github.com/phipleg/keras.git

must be crf branch!!!

$ git checkout crf
$ sudo python3 setup.py build
$ sudo python3 setup.py install

Any chance to get this merged into master keras? It would really be great.

Is it possible to use CRF objective with a FFNN instead of RNN? Just wondering if there was some inherent limitation because all the examples I've ever seen use RNN. Just adding the normal Sequence layer and then ChainCRF will work?

@utkrist
Copy link

utkrist commented May 5, 2017

Does this only work with Theano backend or tensorflow as well?

@xtknight
Copy link

@utkrist I've been using a project with ChainCRF and I've switched between both backends and both seem to work well.

@jxwb088047
Copy link

jxwb088047 commented Jun 27, 2017

@nreimers hello,I am a freshman in NER, can you share your code about NER code with CRF layer on Conll 2003 dataset to help understand of ?Thx

@nreimers
Copy link

@jxwb088047 I will publish my code soon on Github (beginning to mid July). I will let you now as soon as I pushed it to public github..

@jxwb088047
Copy link

@nreimers I am waiting for your great code, and following your update.

@nreimers
Copy link

@jxwb088047 Hi, I uploaded my BiLSTM-(CNN)-CRF code here:
https://github.com/UKPLab/emnlp2017-bilstm-cnn-crf

It can be used to train the models from Huang et al (BiLSTM-CRF), from Ma & Hovy (BiLSTM-CNN-CRF) and from Lample et al. (BiLSTM-LSTM-CRF).

I hope the documentation and the code is helpful enough to give you a good start into this topic.

The code uses CRF code by phipleg, thank you again for contributing this to the community. However, it currently works only with Keras 1.x. Maybe we can update at some point to work with Keras 2.x

@jxwb088047
Copy link

@nreimers Thank you for sharing your codes and relevant materials, I will dig into all your share materials and hope to comprehensively understanding of NER.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.