Skip to content

Commit b7838d6

Browse files
authored
Add __metadata_guard__ function to class BlockedParameter (#3623)
1 parent ee2171e commit b7838d6

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

intel_extension_for_pytorch/cpu/tpp/utils/blocked_layout.py

+10
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
304304
args_data = pytree.tree_map_only(BlockedParameter, lambda x: x._data, args)
305305
return func(*args_data, **kwargs)
306306

307+
@classmethod
308+
def __metadata_guard__(cls, orig_data, other):
309+
return (
310+
orig_data[0] == other[0]
311+
and orig_data[1] == other[1]
312+
and orig_data[2] == other[2]
313+
and orig_data[4] == other[4]
314+
and orig_data[5] == other[5]
315+
)
316+
307317
def __copy__(self):
308318
new_param = BlockedParameter(self._data, requires_grad=self.requires_grad)
309319
for k, v in self.__dict__.items():

0 commit comments

Comments
 (0)