tensor_utils

Functions

add_dummy_dims_left(x, y)
add_dummy_dims_right(x, y)
broadcast(x, y, ignored_dims)
Parameters:

ignored_dims (List[int])

broadcast_all(xs, ignored_dims=[])
broadcast_both(x, y, ignored_dims)
Parameters:

ignored_dims (List[int])

broadcast_both_left(x, y, ignored_dims)
Parameters:

ignored_dims (List[int])

broadcast_cross_product(x, y, dim=-1)
broadcast_gather(src, dim, ind, keepdim=False, **kwargs)
Parameters:

dim (int)

broadcast_interleave(x, counts, inds, dim=-2)
broadcast_scatter(input, dim, ind, src, **kwargs)
cast_to_tensor(x)

Converts scalars or lists of scalars into tensors, and combines lists of tensors into a single tensor. All other input types are returned unchanged. Returned tensors always have shape [1,N,D], where N is the number of tensors combined and D is the dimensionality of each.
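
The description above leaves some details open, so the following is only a rough sketch of the behaviour it describes, not the library's implementation. It assumes PyTorch is installed and makes two loudly labelled guesses: a scalar maps to shape [1,1,1], and a list of D scalars maps to shape [1,1,D]. The name cast_to_tensor_sketch is hypothetical.

```python
import numbers

import torch


def cast_to_tensor_sketch(x):
    """Hypothetical stand-in for cast_to_tensor; shape conventions are assumptions."""
    # ASSUMPTION: a bare scalar becomes a tensor of shape [1, 1, 1].
    if isinstance(x, numbers.Number):
        return torch.tensor([[[float(x)]]])
    if isinstance(x, (list, tuple)) and len(x) > 0:
        # ASSUMPTION: a list of D scalars becomes one row of shape [1, 1, D].
        if all(isinstance(v, numbers.Number) for v in x):
            return torch.tensor(x, dtype=torch.float32).reshape(1, 1, -1)
        # A list of N tensors (each flattened to D elements) is stacked to [1, N, D].
        if all(torch.is_tensor(v) for v in x):
            return torch.stack([v.reshape(-1) for v in x]).unsqueeze(0)
    # Every other input type is returned unchanged.
    return x


scalar_out = cast_to_tensor_sketch(3)
list_out = cast_to_tensor_sketch([1.0, 2.0, 3.0])
stacked_out = cast_to_tensor_sketch([torch.zeros(4), torch.ones(4)])
other_out = cast_to_tensor_sketch("unchanged")
```

Under these assumptions the three tensor results have shapes [1,1,1], [1,1,3], and [1,2,4], and the string is passed through as-is.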

cast_to_tensor_single(x)
concat_dicts(kwargs)

Concatenates a list of dicts that share the same keys. The resulting dictionary has the same keys, with the corresponding values concatenated.
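
As an illustration of the key-wise concatenation described above, here is a minimal sketch that uses plain Python lists in place of tensors. The function name concat_dicts_sketch and the strict key check are assumptions made for the example; the library's own handling of mismatched keys is not documented here.

```python
from typing import Dict, List, Sequence


def concat_dicts_sketch(dicts: Sequence[Dict[str, List]]) -> Dict[str, List]:
    """Merge dicts that share the same keys by concatenating their values."""
    if not dicts:
        return {}
    keys = dicts[0].keys()
    # ASSUMPTION: mismatched keys are treated as an error in this sketch.
    for d in dicts:
        if d.keys() != keys:
            raise KeyError("all dicts must share the same keys")
    # For each key, chain the values from every dict in input order.
    return {k: [item for d in dicts for item in d[k]] for k in keys}


batches = [{"pos": [0.0, 1.0], "id": [7]}, {"pos": [2.0], "id": [8, 9]}]
merged = concat_dicts_sketch(batches)
print(merged)  # {'pos': [0.0, 1.0, 2.0], 'id': [7, 8, 9]}
```

With tensor values, the list comprehension would presumably be replaced by a tensor concatenation along some dimension.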

dot_product(x, y, dim=-1, keepdim=True, out=None)
dot_product_in_place(x, y, dim=-1)
expand_as_left(x, y, offset=0)
Parameters:

offset (int)

expand_as_right(x, y, offset=0)
Parameters:

offset (int)

implements(torch_function)

Registers a torch function override for ScalarTensor.
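
A decorator of this kind usually follows the `__torch_function__` dispatch pattern from the PyTorch "Extending PyTorch" notes. The sketch below shows that general pattern and assumes PyTorch is installed; the ScalarTensor class body, the HANDLED_FUNCTIONS registry name, and the mean override are stand-ins written for illustration, not this module's actual code.

```python
import functools

import torch

# Hypothetical registry mapping torch functions to their overrides.
HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Register `func` as the override used when `torch_function` sees a ScalarTensor."""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class ScalarTensor:
    """Stand-in type: an n x n matrix with `value` on the diagonal."""

    def __init__(self, n, value):
        self._n = n
        self._value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Fall back to NotImplemented when no override has been registered.
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, ScalarTensor) for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(torch.mean)
def mean(input):
    # Mean over all n*n entries: n diagonal entries of `value`, the rest zero.
    return float(input._value) / input._n


result = torch.mean(ScalarTensor(5, 2))
```

Calling torch.mean on the stand-in object is routed through `__torch_function__` to the registered override, so no dense matrix is ever built.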

interpolate(x, y, a)
make_grid(height, width=None, min_coord=-1, max_coord=1, min_coord2=None, max_coord2=None)
mean(xs)
mid_point(x, dim=-1, keepdim=True)
offset(x)
pack_tensor(x, packing)
packed_reorder(x, counts, ids)
pad_dim_left(x, num_dims)
pad_dim_right(x, num_dims)
pad_dims(xs, unsqueeze_dim=-2)
prepare_kwargs(self, func, args, kwargs, initial_args, unique_args)

Combines args and kwargs into a single dict, using default values for any argument that is missing.
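
One way to merge positional arguments, keyword arguments, and defaults into one dict is through the standard library's inspect module, sketched below. This is not the library's implementation: the helper name merge_call_arguments and the example function scale are hypothetical, and the roles of initial_args and unique_args are not documented here, so the sketch leaves them out.

```python
import inspect


def merge_call_arguments(func, args, kwargs):
    """Return one dict of parameter name -> value for a call to `func`."""
    signature = inspect.signature(func)
    # Match positional and keyword arguments to parameter names.
    bound = signature.bind(*args, **kwargs)
    # Fill in declared defaults for parameters the caller did not supply.
    bound.apply_defaults()
    return dict(bound.arguments)


def scale(x, factor=2.0, offset=0.0):
    return x * factor + offset


merged = merge_call_arguments(scale, (3.0,), {"offset": 1.0})
print(merged)  # {'x': 3.0, 'factor': 2.0, 'offset': 1.0}
```

Because `bind` is used rather than `bind_partial`, a missing required argument raises TypeError instead of being silently omitted.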

reduce_max_score(x, scores, dim=-1)
robust_concat(xs)

Concatenates multiple tensors together while broadcasting as necessary to ensure shapes match.
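
The idea of broadcasting before concatenating can be sketched as follows, assuming PyTorch is installed. The concatenation dimension is not stated in the description, so this sketch assumes the last dimension and exposes it as a hypothetical dim parameter; robust_concat_sketch is likewise an illustrative name, not the library function.

```python
import torch


def robust_concat_sketch(xs, dim=-1):
    """Broadcast every dimension except `dim`, then concatenate along `dim`."""
    ndim = max(x.dim() for x in xs)
    # Left-pad shapes with singleton dims so all inputs have the same rank.
    xs = [x.reshape((1,) * (ndim - x.dim()) + tuple(x.shape)) for x in xs]
    dim = dim % ndim
    # Work out the common broadcast shape while ignoring the concat dimension.
    masked = [tuple(1 if i == dim else s for i, s in enumerate(x.shape)) for x in xs]
    target = list(torch.broadcast_shapes(*masked))
    expanded = []
    for x in xs:
        shape = list(target)
        shape[dim] = x.shape[dim]  # keep each input's own size along `dim`
        expanded.append(x.expand(shape))
    return torch.cat(expanded, dim=dim)


a = torch.zeros(4, 1, 3)
b = torch.ones(5, 2)
c = robust_concat_sketch([a, b])
print(c.shape)  # torch.Size([4, 5, 5])
```

Here a is expanded to [4,5,3] and b to [4,5,2], so the result holds the zeros of a in its first three columns and the ones of b in its last two.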

scatter_arg_max(x, inds, dim=-1, dim_size=None)
shuffle(x)
squish(x, start=0, end=1)
Parameters:
  • start (int)

  • end (int)

unpack_tensor(x, packing)
unsqueeze_dims(x, y, insert_dim=0)
unsqueeze_left(x, y, offset=0)
Parameters:

offset (int)

unsqueeze_pack_tensors(xs, packing)
unsqueeze_right(x, y, offset=0)
Parameters:

offset (int)

unsqueeze_until_dim(x, dim, insert_dim=0)
unsquish(x, dim=0, factor=None)
Parameters:
  • dim (int)

  • factor (int | None)