Skip to content

Commit 88bc694

Browse files
committed
numpy: properly parse .tns-formatted tensors
I was previously using some hand-rolled parsing for these, when I should have just looked at how TACO parses the tensors.
1 parent 9a89021 commit 88bc694

File tree

1 file changed

+22
-35
lines changed

1 file changed

+22
-35
lines changed

numpy/util.py

+22-35
Original file line numberDiff line numberDiff line change
@@ -2,66 +2,52 @@
22
import sparse
33
import os
44

5-
# CoordinateListFileLoader loads a file in coordinate list
6-
# format into a list of the coordinates and a list of the values.
7-
class CoordinateListFileLoader:
5+
# TnsFileLoader loads a tensor stored in .tns format.
6+
class TnsFileLoader:
87
def __init__(self):
98
pass
10-
9+
1110
def load(self, path):
12-
dims = []
13-
entries = None
1411
coordinates = []
1512
values = []
13+
dims = []
1614
first = True
1715
with open(path, 'r') as f:
1816
for line in f:
19-
# Skip lines with %, as some downloaded files have these
20-
# at the header as comments.
21-
if line.startswith("%"):
22-
continue
2317
data = line.split(' ')
24-
coords = [int(coord) for coord in data[:len(data) - 1]]
18+
coords = [int(coord) - 1 for coord in data[:len(data) - 1]]
2519
# TODO (rohany): What if we want this to be an integer?
2620
value = float(data[-1])
27-
# If this is the first line being read, then the read
28-
# coordinates and values are actually the size of each
29-
# dimension and the number of non-zeros.
3021
if first:
31-
dims = coords
32-
entries = int(value)
3322
first = False
34-
else:
35-
coordinates.append(coords)
36-
values.append(value)
37-
assert(len(coordinates) == entries)
38-
assert(len(values) == entries)
23+
dims = [0] * len(coords)
24+
for i in range(len(coords)):
25+
dims[i] = max(dims[i], coords[i] + 1)
26+
coordinates.append(coords)
27+
values.append(value)
3928
return dims, coordinates, values
4029

41-
# CoordinateListFileDumper dumps a dictionary of coordinates to values
30+
# TnsFileDumper dumps a dictionary of coordinates to values
4231
# into a coordinate list tensor file.
43-
class CoordinateListFileDumper:
32+
class TnsFileDumper:
4433
def __init__(self):
4534
pass
4635

4736
def dump_dict_to_file(self, shape, data, path):
4837
# Sort the data so that the output is deterministic.
4938
sorted_data = sorted([list(coords) + [value] for coords, value in data.items()])
5039
with open(path, 'w+') as f:
51-
# Write the metadata into the file.
52-
dims = list(shape) + [len(data)]
53-
f.write(" ".join([str(elem) for elem in dims]))
54-
f.write("\n")
5540
for line in sorted_data:
56-
strings = [str(elem) for elem in line]
41+
coords = [str(elem + 1) for elem in line[:len(line) - 1]]
42+
strings = coords + [str(line[-1])]
5743
f.write(" ".join(strings))
5844
f.write("\n")
5945

6046
# ScipySparseTensorLoader loads a sparse tensor from a file into a
6147
# scipy.sparse CSR matrix.
6248
class ScipySparseTensorLoader:
6349
def __init__(self, format):
64-
self.loader = CoordinateListFileLoader()
50+
self.loader = TnsFileLoader()
6551
self.format = format
6652

6753
def load(self, path):
@@ -84,7 +70,7 @@ def load(self, path):
8470
# a pydata.sparse tensor.
8571
class PydataSparseTensorLoader:
8672
def __init__(self):
87-
self.loader = CoordinateListFileLoader()
73+
self.loader = TnsFileLoader()
8874

8975
def load(self, path):
9076
dims, coords, values = self.loader.load(path)
@@ -99,13 +85,13 @@ def load(self, path):
9985
# a random tensor parameterized by the chosen shape and sparsity.
10086
# The key itself is formatted by the dimensions, followed by the
10187
# sparsity. For example, a 250 by 250 tensor with sparsity 0.01
102-
# would have a key of 250x250-0.01.tensor.
88+
# would have a key of 250x250-0.01.tns.
10389
def construct_random_tensor_key(shape, sparsity):
10490
# Get the path to the directory holding random tensors. Error out
10591
# if this isn't set.
10692
path = os.environ['TACO_RANDOM_TENSOR_PATH']
10793
dims = "x".join([str(dim) for dim in shape])
108-
key = "{}-{}.tensor".format(dims, sparsity)
94+
key = "{}-{}.tns".format(dims, sparsity)
10995
return os.path.join(path, key)
11096

11197
# RandomPydataSparseTensorLoader should be used to generate
@@ -126,14 +112,15 @@ def random(self, shape, sparsity):
126112
# dump it to the output file, then return it.
127113
result = sparse.random(shape, density=sparsity)
128114
dok = sparse.DOK(result)
129-
CoordinateListFileDumper().dump_dict_to_file(shape, dok.data, key)
115+
TnsFileDumper().dump_dict_to_file(shape, dok.data, key)
130116
return result
131117

132118
# RandomScipySparseTensorLoader is the same as RandomPydataSparseTensorLoader
133119
# but for scipy.sparse tensors.
134120
class RandomScipySparseTensorLoader:
135121
def __init__(self, format):
136122
self.loader = ScipySparseTensorLoader(format)
123+
self.format = format
137124

138125
def random(self, shape, sparsity):
139126
assert(len(shape) == 2)
@@ -143,7 +130,7 @@ def random(self, shape, sparsity):
143130
return self.loader.load(key)
144131
else:
145132
# Otherwise, create and then dump a tensor.
146-
result = scipy.sparse.random(shape[0], shape[1], density=sparsity, format='csr')
133+
result = scipy.sparse.random(shape[0], shape[1], density=sparsity, format=self.format)
147134
dok = scipy.sparse.dok_matrix(result)
148-
CoordinateListFileDumper().dump_dict_to_file(shape, dict(dok.items()), key)
135+
TnsFileDumper().dump_dict_to_file(shape, dict(dok.items()), key)
149136
return result

0 commit comments

Comments
 (0)