I'm attempting to speed up a multivariate fixed-point iteration algorithm using multiprocessing; however, I'm running into issues with shared data. My solution vector is actually a dictionary keyed by state objects rather than a vector of numbers, and each element of the vector is computed using a different formula. At a high level, my algorithm looks like this:
current_estimate = dict(previous_estimate)  # a copy, not an alias of previous_estimate
while True:
    for state in all_states:
        current_estimate[state] = state.getValue(previous_estimate)
    if norm(current_estimate, previous_estimate) < tolerance:
        break
    else:
        previous_estimate, current_estimate = current_estimate, previous_estimate
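A runnable version of the sequential loop is useful as a baseline for checking any parallel version. It is a Jacobi-style sweep: each new value depends only on previous_estimate, which is what makes the for-loop a candidate for parallelization. This is a sketch in Python 3 under stated assumptions: the `State` class with its `weight`/`offset` parameters, the max-norm used for `norm`, and the `fixed_point`/`max_iter` wrapper are hypothetical, chosen so the iteration is a contraction and converges.

```python
class State(object):
    # Hypothetical stand-in for a state: its next value is an affine function
    # of its previous value, x_new = weight * x_old + offset.
    def __init__(self, weight, offset):
        self.weight = weight
        self.offset = offset

    def getValue(self, est):
        return self.weight * est[self] + self.offset


def norm(current, previous):
    # Assumed max-norm of the difference between two estimates (same keys)
    return max(abs(current[k] - previous[k]) for k in current)


def fixed_point(all_states, initial, tolerance=1e-9, max_iter=10000):
    previous_estimate = dict(initial)
    # A separate dict, so the first sweep does not overwrite previous_estimate
    current_estimate = dict(initial)
    for _ in range(max_iter):
        for state in all_states:
            current_estimate[state] = state.getValue(previous_estimate)
        if norm(current_estimate, previous_estimate) < tolerance:
            return current_estimate
        # Swap roles instead of copying; every key is rewritten next sweep
        previous_estimate, current_estimate = current_estimate, previous_estimate
    raise RuntimeError("no convergence within max_iter iterations")


states = [State(0.5, i) for i in range(5)]
result = fixed_point(states, {state: 0.0 for state in states})
print([round(result[state], 6) for state in states])
```

With weight 0.5, each state's fixed point is offset / (1 - weight), i.e. 2 * i, which gives an easy correctness check for a parallel version of the same loop.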
I'm trying to parallelize the for-loop part with multiprocessing. The previous_estimate variable is read-only and each process only needs to write to one element of current_estimate. My current attempt at rewriting the for-loop is as follows:
import itertools
from multiprocessing import Manager, Pool

# Class and function definitions
class A(object):
    def __init__(self, val):
        self.val = val

    # representative getValue function
    def getValue(self, est):
        return est[self] + self.val

def worker(state, in_est, out_est):
    # each task writes exactly one entry of the output estimate
    out_est[state] = state.getValue(in_est)

def worker_star(a_b_c):
    """ Allow multiple arguments for a pool
        Taken from http://stackoverflow.com/a/5443941/3865495
    """
    return worker(*a_b_c)

# Initialize test environment
manager = Manager()
estimates = manager.dict()
all_states = []
for i in range(5):
    a = A(i)
    all_states.append(a)
    estimates[a] = 0

pool = Pool(processes=2)
# in this test, both names refer to the same shared dict
prev_est = estimates
curr_est = estimates
pool.map(worker_star, itertools.izip(all_states, itertools.repeat(prev_est), itertools.repeat(curr_est)))
pool.close()
pool.join()
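Since previous_estimate is only read and each state contributes exactly one value, the loop body can also be written so that workers return their result instead of writing into a shared dict. A minimal sketch in Python 3, assuming each state carries a unique string `name` (a hypothetical attribute, not in the original class) that is used as the dictionary key; `KeyedState`, `evaluate` and `parallel_step` are illustrative names:

```python
from multiprocessing import Pool


class KeyedState(object):
    # Hypothetical variant of the question's class A: estimates are keyed by
    # the plain string `name`, which still compares equal after pickling.
    def __init__(self, name, val):
        self.name = name
        self.val = val

    def getValue(self, est):
        return est[self.name] + self.val


def evaluate(task):
    # Runs in a worker process: reads its copy of the previous estimate and
    # returns one (key, value) pair instead of writing to shared state.
    state, prev_est = task
    return state.name, state.getValue(prev_est)


def parallel_step(pool, all_states, prev_est):
    # pool.map preserves input order; the parent assembles the new estimate.
    pairs = pool.map(evaluate, [(state, prev_est) for state in all_states])
    return dict(pairs)


if __name__ == "__main__":
    all_states = [KeyedState("s%d" % i, i) for i in range(5)]
    prev_est = {state.name: 0 for state in all_states}
    with Pool(processes=2) as pool:
        curr_est = parallel_step(pool, all_states, prev_est)
    print(curr_est)  # {'s0': 0, 's1': 1, 's2': 2, 's3': 3, 's4': 4}
```

In a full fixed-point loop, the pool would be created once outside `while True:` and `parallel_step` called on each sweep. Workers receive pickled copies of `prev_est`, so this design trades Manager proxy round trips for serialization cost; which is cheaper likely depends on the size of the estimate.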
The issue I'm currently running into is that the objects added to the all_states list are not the same objects as the keys stored in the manager.dict(). I keep getting KeyError exceptions when trying to access entries of the dictionary using elements of the list. While debugging, I found that none of the object IDs match:
print map(id, estimates.keys())
>>> [19558864, 19558928, 19558992, 19559056, 19559120]
print map(id, all_states)
>>> [19416144, 19416208, 19416272, 19416336, 19416400]
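The mismatched IDs are consistent with how Manager proxies work: arguments sent to the manager process, including dict keys, are pickled, and the manager stores the unpickled copies, which are different objects from the originals (and `keys()` pickles copies back again). Because class A defines neither `__eq__` nor `__hash__`, its instances hash and compare by identity, so a copy never matches the original key. The Python 3 sketch below reproduces this with a plain pickle round trip and shows that value-based `__eq__`/`__hash__` makes a copy usable as a key; the `roundtrip` helper and the `KeyedA` class are hypothetical, for illustration only.

```python
import pickle


class A(object):
    # Same shape as the question's class: no __eq__ or __hash__, so
    # instances compare and hash by identity.
    def __init__(self, val):
        self.val = val


class KeyedA(object):
    # Hypothetical variant whose equality and hash depend on `val`, so a
    # pickled-and-unpickled copy is interchangeable as a dict key.
    def __init__(self, val):
        self.val = val

    def __eq__(self, other):
        return isinstance(other, KeyedA) and self.val == other.val

    def __ne__(self, other):  # needed on Python 2, harmless on Python 3
        return not self == other

    def __hash__(self):
        return hash(self.val)


def roundtrip(obj):
    # Mimics what happens when a key is sent to the manager process
    return pickle.loads(pickle.dumps(obj))


a = A(1)
copy_a = roundtrip(a)
print(copy_a is a, copy_a == a)  # False False: a new object, not equal
print(copy_a in {a: 0})          # False: lookup by the copy fails

k = KeyedA(1)
copy_k = roundtrip(k)
print(copy_k is k, copy_k == k)  # False True: a new object, but equal
print(copy_k in {k: 0})          # True: lookup by the copy succeeds
```

With value-based equality, the key's hash must only depend on fields that do not change while the object is stored in a dict.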