from numpy import unravel_index, prod, arange, asarray, float64
from itertools import product
from bolt.construct import ConstructBase
from bolt.spark.array import BoltArraySpark
from bolt.spark.utils import get_kv_shape, get_kv_axes

class ConstructSpark(ConstructBase):

@staticmethod
    def array(a, context=None, axis=(0,), dtype=None):
"""
Create a spark bolt array from a local array.
Parameters
----------
a : array-like
An array, any object exposing the array interface, an
object whose __array__ method returns an array, or any
(nested) sequence.
context : SparkContext
A context running Spark. (see pyspark)
axis : tuple, optional, default=(0,)
Which axes to distribute the array along. The resulting
distributed object will use keys to represent these axes,
with the remaining axes represented by values.
dtype : data-type, optional, default=None
The desired data-type for the array. If None, will
be determined from the data. (see numpy)
Returns
-------
BoltArraySpark
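
        Examples
        --------
        A minimal sketch; assumes a live SparkContext bound to ``sc``.

        >>> import numpy as np
        >>> x = ConstructSpark.array(np.arange(24).reshape(2, 3, 4), context=sc)
        >>> x.shape
        (2, 3, 4)
        >>> x.split
        1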
"""
        if dtype is None:
            arry = asarray(a)
            dtype = arry.dtype
        else:
            arry = asarray(a, dtype)
        shape = arry.shape
        ndim = len(shape)

        # handle the axes specification and transpose if necessary
        axes = ConstructSpark._format_axes(axis, arry.shape)
        key_axes, value_axes = get_kv_axes(arry.shape, axes)
        permutation = key_axes + value_axes
        arry = arry.transpose(*permutation)
        # recapture the shape after the transpose so the key/value shapes
        # (and the final array shape) reflect the permuted axes
        shape = arry.shape
        split = len(axes)

        if split < 1:
            raise ValueError("split axis must be greater than 0, got %d" % split)
        if split > len(shape):
            raise ValueError("split axis must not exceed number of axes %d, got %d" % (ndim, split))

        key_shape = shape[:split]
        val_shape = shape[split:]

        # enumerate the key tuples, flatten the values along the key axes,
        # and zip them into (key, value) records for the RDD
        keys = zip(*unravel_index(arange(0, int(prod(key_shape))), key_shape))
        vals = arry.reshape((prod(key_shape),) + val_shape)
        rdd = context.parallelize(zip(keys, vals))
        return BoltArraySpark(rdd, shape=shape, split=split, dtype=dtype)

    @staticmethod
    def ones(shape, context=None, axis=(0,), dtype=float64):
"""
Create a spark bolt array of ones.
Parameters
----------
shape : tuple
The desired shape of the array.
context : SparkContext
A context running Spark. (see pyspark)
axis : tuple, optional, default=(0,)
Which axes to distribute the array along. The resulting
distributed object will use keys to represent these axes,
with the remaining axes represented by values.
dtype : data-type, optional, default=float64
The desired data-type for the array. If None, will
be determined from the data. (see numpy)
Returns
-------
BoltArraySpark
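
        Examples
        --------
        A minimal sketch; assumes a live SparkContext bound to ``sc``.

        >>> ConstructSpark.ones((2, 3, 4), context=sc).shape
        (2, 3, 4)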
"""
from numpy import ones
return ConstructSpark._wrap(ones, shape, context, axis, dtype)

    @staticmethod
    def zeros(shape, context=None, axis=(0,), dtype=float64):
"""
Create a spark bolt array of zeros.
Parameters
----------
shape : tuple
The desired shape of the array.
context : SparkContext
A context running Spark. (see pyspark)
axis : tuple, optional, default=(0,)
Which axes to distribute the array along. The resulting
distributed object will use keys to represent these axes,
with the remaining axes represented by values.
dtype : data-type, optional, default=float64
The desired data-type for the array. If None, will
be determined from the data. (see numpy)
Returns
-------
BoltArraySpark
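
        Examples
        --------
        A minimal sketch; assumes a live SparkContext bound to ``sc``.

        >>> z = ConstructSpark.zeros((2, 3), context=sc)
        >>> z.shape
        (2, 3)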
"""
from numpy import zeros
return ConstructSpark._wrap(zeros, shape, context, axis, dtype)

    @staticmethod
    def concatenate(arrays, axis=0):
"""
Join two bolt arrays together, at least one of which is in spark.
Parameters
----------
arrays : tuple
A pair of arrays. At least one must be a spark array,
the other can be a local bolt array, a local numpy array,
or an array-like.
axis : int, optional, default=0
The axis along which the arrays will be joined.
Returns
-------
BoltArraySpark
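
        Examples
        --------
        A minimal sketch; assumes a live SparkContext bound to ``sc``.

        >>> import numpy as np
        >>> a = ConstructSpark.array(np.zeros((2, 3)), context=sc)
        >>> ConstructSpark.concatenate((a, np.ones((2, 3))), axis=0).shape
        (4, 3)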
"""
        if not isinstance(arrays, tuple):
            raise ValueError("arrays must be passed as a tuple")
        if not len(arrays) == 2:
            raise NotImplementedError("spark concatenation only supports two arrays")

        first, second = arrays
        if isinstance(first, BoltArraySpark):
            return first.concatenate(second, axis)
        elif isinstance(second, BoltArraySpark):
            first = ConstructSpark.array(first, second._rdd.context)
            return first.concatenate(second, axis)
        else:
            raise ValueError("at least one array must be a spark bolt array")

    @staticmethod
    def _argcheck(*args, **kwargs):
"""
Check that arguments are consistent with spark array construction.
Conditions are:
(1) a positional argument is a SparkContext
(2) keyword arg 'context' is a SparkContext
(3) an argument is a BoltArraySpark, or
(4) an argument is a nested list containing a BoltArraySpark
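
        Examples
        --------
        A minimal sketch; assumes a live SparkContext bound to ``sc``.

        >>> ConstructSpark._argcheck(sc)
        True
        >>> ConstructSpark._argcheck([1, 2, 3])
        False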
"""
try:
from pyspark import SparkContext
except ImportError:
return False
cond1 = any([isinstance(arg, SparkContext) for arg in args])
cond2 = isinstance(kwargs.get('context', None), SparkContext)
cond3 = any([isinstance(arg, BoltArraySpark) for arg in args])
cond4 = any([any([isinstance(sub, BoltArraySpark) for sub in arg])
if isinstance(arg, (tuple, list)) else False for arg in args])
return cond1 or cond2 or cond3 or cond4

    @staticmethod
    def _format_axes(axes, shape):
"""
        Format target axes given an array shape.
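
        Examples
        --------
        These illustrate the normalization rules above; no Spark is needed.

        >>> ConstructSpark._format_axes(0, (2, 3, 4))
        (0,)
        >>> ConstructSpark._format_axes([0, 1], (2, 3, 4))
        (0, 1)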
"""
if isinstance(axes, int):
axes = (axes,)
elif isinstance(axes, list) or hasattr(axes, '__iter__'):
axes = tuple(axes)
if not isinstance(axes, tuple):
raise ValueError("axes argument %s in the constructor not specified correctly" % str(axes))
if min(axes) < 0 or max(axes) > len(shape) - 1:
raise ValueError("invalid key axes %s given shape %s" % (str(axes), str(shape)))
return axes

    @staticmethod
    def _wrap(func, shape, context=None, axis=(0,), dtype=None):
"""
        Wrap an existing numpy constructor in a parallelized construction.
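
        Examples
        --------
        A minimal sketch; assumes a live SparkContext bound to ``sc``.

        >>> from numpy import ones, float64
        >>> ConstructSpark._wrap(ones, (2, 3), context=sc, dtype=float64).shape
        (2, 3)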
"""
        if isinstance(shape, int):
            shape = (shape,)
        key_shape, value_shape = get_kv_shape(shape, ConstructSpark._format_axes(axis, shape))
        split = len(key_shape)

        # make the keys
        rdd = context.parallelize(list(product(*[arange(x) for x in key_shape])))

        # use a map to make the arrays in parallel
        rdd = rdd.map(lambda x: (x, func(value_shape, dtype, order='C')))

        return BoltArraySpark(rdd, shape=shape, split=split, dtype=dtype)