Skip to content

Commit 4dbfb57

Browse files
committed
merge Dan's PR about Array2
1 parent 37aafd8 commit 4dbfb57

File tree

4 files changed

+68
-96
lines changed

4 files changed

+68
-96
lines changed

k2/csrc/array.h

+26-96
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,18 @@
1414

1515
namespace k2 {
1616

17-
1817
/*
19-
We will use e.g. StridedPtr<int32_t, T> when the stride is not 1, and otherwise
20-
just T* (which presumably be faster).
18+
We will use e.g. StridedPtr<int32_t, T> when the stride is not 1, and
19+
otherwise just T* (which will presumably be faster).
2120
*/
2221
template <typename I, typename T>
23-
class StridedPtr {
22+
struct StridedPtr {
2423
T *data;
2524
I stride;
26-
T &operator [] (I i) { return data[i]; }
27-
StridedPtr(T *data, I stride): data(data), stride(stride) { }
25+
T &operator[](I i) { return data[i]; }
26+
StridedPtr(T *data, I stride) : data(data), stride(stride) {}
2827
};
2928

30-
31-
3229
/* MIGHT NOT NEED THIS */
3330
template <typename I, typename Ptr>
3431
struct Array1 {
@@ -37,23 +34,20 @@ struct Array1 {
3734
using IndexT = I;
3835
using PtrT = Ptr;
3936

40-
4137
// 'begin' and 'end' are the first and one-past-the-last indexes into `data`
4238
// that we are allowed to use.
4339
IndexT begin;
4440
IndexT end;
4541

4642
PtrT data;
47-
48-
4943
};
5044

51-
5245
/*
5346
This struct stores the size of an Array2 object; it will generally be used as
5447
an output argument by functions that work out this size.
5548
*/
56-
template <typename I> struct Array2Size {
49+
template <typename I>
50+
struct Array2Size {
5751
using IndexT = I;
5852
// `size1` is the top-level size of the array, equal to the object's .size
5953
// element
@@ -62,8 +56,7 @@ template <typename I> struct Array2Size {
6256
// o->indexes[o->size] - o->indexes[0] (if the Array2 object o is
6357
// initialized).
6458
I size2;
65-
}
66-
59+
};
6760

6861
template <typename I, typename Ptr>
6962
struct Array2 {
@@ -79,113 +72,50 @@ struct Array2 {
7972
// not required that indexes[0] == 0, it may be
8073
// greater than 0.
8174

82-
PtrT data; // `data` might be an actual pointer, or might be some object
83-
// supporting operator []. data[indexes[0]] through
84-
// data[indexes[size] - 1] must be accessible through this
85-
// object.
75+
PtrT data; // `data` might be an actual pointer, or might be some object
76+
// supporting operator []. data[indexes[0]] through
77+
// data[indexes[size] - 1] must be accessible through this
78+
// object.
8679

8780
/* initialized definition:
8881
8982
An Array2 object is initialized if its `size` member is set and its
9083
`indexes` and `data` pointer allocated, and the values of its `indexes`
9184
array are set for indexes[0] and indexes[size].
9285
*/
93-
9486
};
9587

96-
9788
template <typename I, typename Ptr>
9889
struct Array3 {
99-
using IndexT = I;
100-
using PtrT = Ptr;
101-
102-
// Irregular three dimensional array of something, like vector<vector<vetor<X> > >
103-
// where Ptr is or behaves like X*.
90+
// Irregular three-dimensional array of something, like vector<vector<vector<X>
91+
// > > where Ptr is or behaves like X*.
10492
using IndexT = I;
10593
using PtrT = Ptr;
10694

10795
IndexT size;
108-
const IndexT *indexes1; // indexes1[0,1,...size] should be defined; note, this
109-
// means the array must be of at least size+1. We
110-
// require that indexes[i] <= indexes[i+1], but it is
111-
// not required that indexes[0] == 0, it may be
96+
const IndexT *indexes1; // indexes1[0,1,...size] should be defined; note,
97+
// this means the array must be of at least size+1.
98+
// We require that indexes1[i] <= indexes1[i+1], but it
99+
// is not required that indexes1[0] == 0, it may be
112100
// greater than 0.
113101

114102
const IndexT *indexes2; // indexes2[indexes1[0]]
115103
// .. indexes2[indexes1[size]-1] should be defined.
116104

117-
Ptr data; // `data` might be an actual pointer, or might be some object
118-
// supporting operator []. data[indexes[0]] through
119-
// data[indexes[size] - 1] must be accessible through this
120-
// object.
121-
105+
Ptr data; // `data` might be an actual pointer, or might be some object
106+
// supporting operator []. data[indexes[0]] through
107+
// data[indexes[size] - 1] must be accessible through this
108+
// object.
122109

123-
Array2 operator [] (I i) {
124-
// ...
110+
Array2<I, Ptr> operator[](I i) {
111+
// TODO(haowen): fill real data here
112+
Array2<I, Ptr> array;
113+
return array;
125114
}
126-
127-
};
128-
129-
130-
131-
132-
133-
134-
// we'd put the following in fsa.h
135-
using Cfsa = Array2<int32_t, Arc>;
136-
using CfsaVec = Array3<int32_t, Arc>;
137-
138-
139-
140-
141-
class FstInverter {
142-
/* Constructor. Lightweight. */
143-
FstInverter(const Fsa &fsa_in, const AuxLabels &labels_in);
144-
145-
/*
146-
Do enough work that know now much memory will be needed, and output
147-
that information
148-
@param [out] fsa_size The num-states and num-arcs of the FSA
149-
will be written to here
150-
@param [out] aux_size The number of lists in the AuxLabels
151-
output (==num-arcs) and the number of
152-
elements will be written to here.
153-
*/
154-
void GetSizes(Array2Size<int32_t> *fsa_size,
155-
Array2Size<int32_t> *aux_size);
156-
157-
/*
158-
Finish the operation and output inverted FSA to `fsa_out` and
159-
auxiliary labels to `labels_out`.
160-
@param [out] fsa_out The inverted FSA will be written to
161-
here. Must be initialized; search for
162-
'initialized definition' in class Array2
163-
in array.h for meaning.
164-
@param [out] labels_out The auxiliary labels will be written to
165-
here. Must be initialized; search for
166-
'initialized definition' in class Array2
167-
in array.h for meaning.
168-
*/
169-
void GetOutput(Fsa *fsa_out,
170-
AuxLabels *labels_out);
171-
private:
172-
// ...
173115
};
174116

175117
// Note: we can create Array4 later if we need it.
176118

177-
178-
void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out,
179-
AuxLabels *aux_labels_out) {
180-
181-
182-
/*
183-
void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
184-
std::vector<std::vector<int32_t>> *arc_derivs);
185-
*/
186-
187-
188-
189119
} // namespace k2
190120

191121
#endif // K2_CSRC_ARRAY_H_

k2/csrc/aux_labels.h

+37
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <vector>
1111

12+
#include "k2/csrc/array.h"
1213
#include "k2/csrc/fsa.h"
1314
#include "k2/csrc/fsa_util.h"
1415
#include "k2/csrc/properties.h"
@@ -46,6 +47,9 @@ struct AuxLabels {
4647
std::vector<int32_t> labels;
4748
};
4849

50+
// TODO(haowen): replace AuxLabels above with the definition below
51+
using AuxLabels_ = Array2<int32_t, int32_t>;
52+
4953
// Swap AuxLabels; it's cheap to do this as we are actually doing a shallow swap.
5054
void Swap(AuxLabels *labels1, AuxLabels *labels2);
5155

@@ -95,6 +99,39 @@ void MapAuxLabels2(const AuxLabels &labels_in,
9599
void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out,
96100
AuxLabels *aux_labels_out);
97101

102+
class FstInverter {
103+
/* Constructor. Lightweight. */
104+
FstInverter(const Fsa &fsa_in, const AuxLabels &labels_in);
105+
106+
/*
107+
Do enough work to know how much memory will be needed, and output
108+
that information
109+
@param [out] fsa_size The num-states and num-arcs of the FSA
110+
will be written to here
111+
@param [out] aux_size The number of lists in the AuxLabels
112+
output (==num-arcs) and the number of
113+
elements will be written to here.
114+
*/
115+
void GetSizes(Array2Size<int32_t> *fsa_size, Array2Size<int32_t> *aux_size);
116+
117+
/*
118+
Finish the operation and output inverted FSA to `fsa_out` and
119+
auxiliary labels to `labels_out`.
120+
@param [out] fsa_out The inverted FSA will be written to
121+
here. Must be initialized; search for
122+
'initialized definition' in class Array2
123+
in array.h for meaning.
124+
@param [out] labels_out The auxiliary labels will be written to
125+
here. Must be initialized; search for
126+
'initialized definition' in class Array2
127+
in array.h for meaning.
128+
*/
129+
void GetOutput(Fsa *fsa_out, AuxLabels *labels_out);
130+
131+
private:
132+
// ...
133+
};
134+
98135
} // namespace k2
99136

100137
#endif // K2_CSRC_AUX_LABELS_H_

k2/csrc/fsa.h

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <vector>
1414

1515
#include "glog/logging.h"
16+
#include "k2/csrc/array.h"
1617
#include "k2/csrc/util.h"
1718

1819
namespace k2 {
@@ -129,6 +130,10 @@ struct Fsa {
129130
}
130131
};
131132

133+
// TODO(haowen): replace Cfsa and CfsaVec with the definitions below
134+
using Cfsa_ = Array2<int32_t, Arc>;
135+
using CfsaVec_ = Array3<int32_t, Arc>;
136+
132137
/*
133138
Cfsa is a 'const' FSA, which we'll use as the input to operations. It is
134139
designed in such a way that the storage underlying it may either be an Fsa

notes/array.py notes/array.txt

File renamed without changes.

0 commit comments

Comments
 (0)