Skip to content

Commit 9f18d9c

Browse files
author
liuyu
committed
support A-softmax
1 parent 0d3b5db commit 9f18d9c

File tree

6 files changed

+906
-0
lines changed

6 files changed

+906
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#ifndef CAFFE_MARGIN_INNER_PRODUCT_LAYER_HPP_
2+
#define CAFFE_MARGIN_INNER_PRODUCT_LAYER_HPP_
3+
4+
#include <vector>
5+
6+
#include "caffe/blob.hpp"
7+
#include "caffe/layer.hpp"
8+
#include "caffe/proto/caffe.pb.h"
9+
10+
namespace caffe {
11+
12+
/**
13+
* @brief Also known as a "marginal fully-connected" layer, computes an marginal inner product
14+
* with a set of learned weights, and (optionally) adds biases.
15+
*
16+
* TODO(dox): thorough documentation for Forward, Backward, and proto params.
17+
*/
18+
template <typename Dtype>
19+
class MarginInnerProductLayer : public Layer<Dtype> {
20+
public:
21+
explicit MarginInnerProductLayer(const LayerParameter& param)
22+
: Layer<Dtype>(param) {}
23+
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
24+
const vector<Blob<Dtype>*>& top);
25+
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
26+
const vector<Blob<Dtype>*>& top);
27+
28+
virtual inline const char* type() const { return "MarginInnerProduct"; }
29+
virtual inline int ExactNumBottomBlobs() const { return 2; }
30+
virtual inline int MaxTopBlobs() const { return 2; }
31+
32+
protected:
33+
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
34+
const vector<Blob<Dtype>*>& top);
35+
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
36+
const vector<Blob<Dtype>*>& top);
37+
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
38+
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
39+
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
40+
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
41+
42+
int M_;
43+
int K_;
44+
int N_;
45+
46+
MarginInnerProductParameter_MarginType type_;
47+
48+
// common variables
49+
Blob<Dtype> x_norm_;
50+
Blob<Dtype> cos_theta_;
51+
Blob<Dtype> sign_0_; // sign_0 = sign(cos_theta)
52+
// for DOUBLE type
53+
Blob<Dtype> cos_theta_quadratic_;
54+
// for TRIPLE type
55+
Blob<Dtype> sign_1_; // sign_1 = sign(abs(cos_theta) - 0.5)
56+
Blob<Dtype> sign_2_; // sign_2 = sign_0 * (1 + sign_1) - 2
57+
Blob<Dtype> cos_theta_cubic_;
58+
// for QUADRA type
59+
Blob<Dtype> sign_3_; // sign_3 = sign_0 * sign(2 * cos_theta_quadratic_ - 1)
60+
Blob<Dtype> sign_4_; // sign_4 = 2 * sign_0 + sign_3 - 3
61+
Blob<Dtype> cos_theta_quartic_;
62+
63+
int iter_;
64+
Dtype lambda_;
65+
66+
};
67+
68+
} // namespace caffe
69+
70+
#endif // CAFFE_MARGIN_INNER_PRODUCT_LAYER_HPP_

0 commit comments

Comments
 (0)