|
1 | 1 | import scipy.sparse
|
2 | 2 | import sparse
|
3 | 3 | import os
|
| 4 | +import glob |
| 5 | + |
| 6 | +# Get the path to the directory holding random tensors. Error out |
| 7 | +# if this isn't set. |
| 8 | +TENSOR_PATH = os.environ['TACO_TENSOR_PATH'] |
4 | 9 |
|
5 | 10 | # TnsFileLoader loads a tensor stored in .tns format.
|
6 | 11 | class TnsFileLoader:
|
@@ -87,12 +92,10 @@ def load(self, path):
|
87 | 92 | # sparsity. For example, a 250 by 250 tensor with sparsity 0.01
|
88 | 93 | # would have a key of 250x250-0.01.tns.
|
89 | 94 | def construct_random_tensor_key(shape, sparsity):
|
90 |
| - # Get the path to the directory holding random tensors. Error out |
91 |
| - # if this isn't set. |
92 |
| - path = os.environ['TACO_RANDOM_TENSOR_PATH'] |
| 95 | + path = TENSOR_PATH |
93 | 96 | dims = "x".join([str(dim) for dim in shape])
|
94 | 97 | key = "{}-{}.tns".format(dims, sparsity)
|
95 |
| - return os.path.join(path, key) |
| 98 | + return os.path.join(path, "random", key) |
96 | 99 |
|
97 | 100 | # RandomPydataSparseTensorLoader should be used to generate
|
98 | 101 | # random pydata.sparse tensors. It caches the loaded tensors
|
@@ -134,3 +137,27 @@ def random(self, shape, sparsity):
|
134 | 137 | dok = scipy.sparse.dok_matrix(result)
|
135 | 138 | TnsFileDumper().dump_dict_to_file(shape, dict(dok.items()), key)
|
136 | 139 | return result
|
| 140 | + |
| 141 | +# FROSTTTensor represents a tensor in the FROSTT dataset. |
| 142 | +class FROSTTTensor: |
| 143 | + def __init__(self, path): |
| 144 | + self.path = path |
| 145 | + |
| 146 | + def __str__(self): |
| 147 | + f = os.path.split(self.path)[1] |
| 148 | + return f.replace(".tns", "") |
| 149 | + |
| 150 | + def load(self): |
| 151 | + return PydataSparseTensorLoader().load(self.path) |
| 152 | + |
| 153 | +# TensorCollectionFROSTT represents the set of all FROSTT tensors. |
| 154 | +class TensorCollectionFROSTT: |
| 155 | + def __init__(self): |
| 156 | + data = os.path.join(TENSOR_PATH, "FROSTT") |
| 157 | + frostttensors = glob.glob(os.path.join(data, "*.tns")) |
| 158 | + self.tensors = [FROSTTTensor(t) for t in frostttensors] |
| 159 | + |
| 160 | + def getTensors(self): |
| 161 | + return self.tensors |
| 162 | + def getTensorNames(self): |
| 163 | + return [str(tensor) for tensor in self.getTensors()] |
0 commit comments