SeComLib
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Macros Pages
secure_recommendations/server.cpp
Go to the documentation of this file.
1 /*
2 SeComLib
3 Copyright 2012-2013 TU Delft, Information Security & Privacy Lab (http://isplab.tudelft.nl/)
4 
5 Contributors:
6 Inald Lagendijk (R.L.Lagendijk@TUDelft.nl)
7 Mihai Todor (todormihai@gmail.com)
8 Thijs Veugen (P.J.M.Veugen@tudelft.nl)
9 Zekeriya Erkin (z.erkin@tudelft.nl)
10 
11 Licensed under the Apache License, Version 2.0 (the "License");
12 you may not use this file except in compliance with the License.
13 You may obtain a copy of the License at
14 
15 http://www.apache.org/licenses/LICENSE-2.0
16 
17 Unless required by applicable law or agreed to in writing, software
18 distributed under the License is distributed on an "AS IS" BASIS,
19 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 See the License for the specific language governing permissions and
21 limitations under the License.
22 */
30 #include "server.h"
31 
32 namespace SeComLib {
33 namespace SecureRecommendations {
41  Server::Server (const PaillierPublicKey &key): cryptoProvider(key), clientPublicKey(key) {
42  //set the kernel
43  this->kernel = SecureSvm::GetKernel(Utils::Config::GetInstance().GetParameter<std::string>("SecureRecommendations.kernel"));
44 
45  //initialize other parameters
46  this->contentItemCount = Utils::Config::GetInstance().GetParameter<unsigned int>("SecureRecommendations.Server.contentItemCount");
47  this->blindingFactorSize = Utils::Config::GetInstance().GetParameter<unsigned int>("SecureRecommendations.Server.blindingFactorSize");
48  this->modelFileExtension = Utils::Config::GetInstance().GetParameter<std::string>("SecureRecommendations.Server.modelFileExtension");
49 
51  this->medicalRelevanceModelsDirectory = Utils::Config::GetInstance().GetParameter<std::string>("SecureRecommendations.Server.MedicalRelevanceBlock.svmModelsFolder") + Utils::Config::GetInstance().GetParameter<std::string>("SecureRecommendations.kernel") + "/";
52  this->medicalRelevanceClusterCount = Utils::Config::GetInstance().GetParameter<unsigned int>("SecureRecommendations.Server.MedicalRelevanceBlock.clusters");
53 
55  this->safetyModelsDirectory = Utils::Config::GetInstance().GetParameter<std::string>("SecureRecommendations.Server.SafetyBlock.svmModelsFolder") + Utils::Config::GetInstance().GetParameter<std::string>("SecureRecommendations.kernel") + "/";
56 
60  }
61 
67  void Server::SetClient (const std::shared_ptr<const Hub> &client) {
68  this->client = client;
69  }
70 
// NOTE(review): this fragment is presumably the body of Server::Initialize() ("Initializes the internal SVM
// instances." in the member index). Its signature line and source lines 81-82, 86-87, 91 and 98-99 are not
// present in this listing, so the calls that populate medicalRelevanceSvms / safetySvms (whose sizes are
// printed below) cannot be seen here.
79  std::cout << "Preprocessing model files data." << std::endl;
80 
83 
84  std::cout << "Finished preprocessing data from " << this->medicalRelevanceSvms.size() << " medical relevance block model files." << std::endl;
85 
88 
89  std::cout << "Finished preprocessing data from " << this->safetySvms.size() << " safety block model files." << std::endl;
90 
// Generate contentItemCount dummy plaintext content items of the form (r + 1) * 10, with r = GetRandomInteger(BigInteger(9))
// (assumes GetRandomInteger(n) draws from [0, n) - TODO confirm).
92  for (unsigned int i = 0; i < this->contentItemCount; ++i) {
93  this->dummyContentItems.emplace_back((RandomProvider::GetInstance().GetRandomInteger(BigInteger(9)) + 1) * 10);
94  }
95 
96  std::cout << "Finished generating dummy content items." << std::endl;
97 
// Generate contentItemCount dummy encrypted preference scores.
// NOTE(review): here the "+ 1" is applied to the bound (GetRandomInteger(BigInteger(9) + 1)), whereas the content-item
// loop above adds 1 to the random result. The two value ranges therefore differ; confirm which one is intended.
100  for (unsigned int i = 0; i < this->contentItemCount; ++i) {
101  this->dummyEncryptedPreferenceScores.emplace_back(this->cryptoProvider.EncryptInteger(RandomProvider::GetInstance().GetRandomInteger(BigInteger(9) + 1)));
102  }
103 
104  std::cout << "Finished generating dummy preferences block scores." << std::endl;
105  }
106 
119  void Server::InteractiveSecureDivision (const BigInteger &numerator, SecureSvm::EncryptedVector &denominators) const {
120  std::vector<BigInteger> blindingFactors;
121 
123  for (size_t i = 0; i < denominators.size(); ++i) {
125  blindingFactors.emplace_back(RandomProvider::GetInstance().GetRandomInteger(this->blindingFactorSize) + 1);
126 
127  //blind the encrypted denominators
128  denominators[i] = denominators[i] * blindingFactors.back();
129  }
130 
132  SecurePermutation permutation(denominators.size());
133 
135  permutation.Permute(denominators);
136 
138  this->client.lock()->EvaluateDivision(numerator, denominators);
139 
141  permutation.InvertPermutation(denominators);
142 
144  for (size_t i = 0; i < denominators.size(); ++i) {
145  //debug kernel values
146  //this->client.lock()->DebugValue(denominators[i]);
147 
148  denominators[i] = denominators[i] * blindingFactors[i];
149  }
150  }
151 
166  void Server::GetAccuracyPredictions (Server::EncryptedClusterVotes &clusterVotes, SecureSvm::EncryptedVector &safetyPredictions, const TestDataRow &medicalRelevanceTestData, const std::map<std::string, TestDataRow> &safetyTestData) const {
167  //container for the SVM predictions (both medical relevance and safety)
168  Server::EncryptedSvmValues svmPredictions;
169 
171 
172  for (size_t i = 0; i < this->medicalRelevanceSvms.size(); ++i) {
173  //debug
174  //std::string start = Utils::DateTime::Now();
175  svmPredictions.emplace_back(this->medicalRelevanceSvms[i]->Predict(medicalRelevanceTestData.x, medicalRelevanceTestData.xx, medicalRelevanceTestData.xSquared));
176  //std::cout << "medical" << i << ": "; this->client.lock()->DebugValue(svmPredictions.back());
177  //std::cout << "start: " << start << " end: " << Utils::DateTime::Now() << " SVM: " << i << std::endl;
178  }
179 
181  for (size_t i = 0; i < this->safetySvms.size(); ++i) {
182  std::map<std::string, TestDataRow>::const_iterator safetyTestDataIterator = safetyTestData.find(safetySvms[i]->GetUnsafeClasses());
183  if (safetyTestData.end() != safetyTestDataIterator) {
184  //debug
185  //std::string start = Utils::DateTime::Now();
186  svmPredictions.emplace_back(this->safetySvms[i]->Predict((*safetyTestDataIterator).second.x, (*safetyTestDataIterator).second.xx, (*safetyTestDataIterator).second.xSquared));
187  //std::cout << safetyTestDataIterator->first << ": "; this->client.lock()->DebugValue(svmPredictions.back());
188  //std::cout << "start: " << start << " end: " << Utils::DateTime::Now() << " SVM: " << i << " nSV: " << this->safetySvms[i]->model->l << std::endl;
189  }
190  else {
192  throw std::runtime_error("Missing safety test data for unsafe classes: " + safetySvms[i]->GetUnsafeClasses());
193  }
194  }
195 
197 
198  //overwrite the svmPredictions with the data received from the client
199  this->interactiveSignEvaluation(svmPredictions);
200 
201  //debug
202  /*
203  for (size_t i = 0; i < svmPredictions.size(); ++i) {
204  this->client.lock()->DebugValue(svmPredictions[i]);
205  }
206  */
207 
208  //extract and remove the safety SVM predictions from the svmPredictions vector
209  std::move(svmPredictions.begin() + this->medicalRelevanceSvms.size(), svmPredictions.end(), std::back_inserter(safetyPredictions));
210  svmPredictions.erase(svmPredictions.begin() + this->medicalRelevanceSvms.size(), svmPredictions.end());
211 
213  clusterVotes = this->getTotalClusterVotes(svmPredictions);
214 
217 
218  //overwrite the clusterVotes variable
219  this->interactiveMaximumEvaluation(clusterVotes);
220  }
221 
233  void Server::GetPerformancePredictions (SecureSvm::EncryptedVector &firstTwoBlocksPredictions, SecureSvm::EncryptedVector &safetyPredictions, const TestDataRow &medicalRelevanceTestData, const std::map<std::string, TestDataRow> &safetyTestData) const {
234  Server::EncryptedClusterVotes clusterVotes;
235 
237  this->GetAccuracyPredictions(clusterVotes, safetyPredictions, medicalRelevanceTestData, safetyTestData);
238 
240  for (unsigned long i = 0; i < this->contentItemCount; ++i) {
241  //initialize the accumulator
242  Paillier::Ciphertext result = this->encryptedZero;
243 
244  for (unsigned long j = 0; j < this->medicalRelevanceClusterCount; ++j) {
245  //multiply the content items with the cluster votes and add the products
246  result = result + clusterVotes[j] * this->dummyContentItems[i];
247  }
248 
249  //combine the medical relevance and the preferences block
250  firstTwoBlocksPredictions.emplace_back(result + this->dummyEncryptedPreferenceScores[i]);
251  }
252  }
253 
257  std::vector<std::string> Server::GetSafetyBlockSvmsUnsafeClasses () const {
258  std::vector<std::string> output;
259 
260  for (size_t i = 0; i < this->safetySvms.size(); ++i) {
261  output.emplace_back(this->safetySvms[i]->GetUnsafeClasses());
262  }
263 
264  return output;
265  }
266 
270  std::deque<std::string> Server::GetSafetyBlockModelFiles () const {
271  return this->safetySvmModelFiles;
272  }
273 
277  void Server::DebugValue (const Paillier::Ciphertext &value) const {
278  this->client.lock()->DebugValue(value);
279  }
280 
288  void Server::loadMedicalRelevanceSvmModels (const std::string &modelsDirectory) {
289  for (unsigned int i = 0; i < this->medicalRelevanceClusterCount; ++i) {
290  for (unsigned int j = i + 1; j < this->medicalRelevanceClusterCount; ++j) {
291  std::stringstream fileNameBuilder;
292 
293  //construct the file path
294  fileNameBuilder << "cluster" << (i + 1) << "v" << (j + 1) << "." + this->modelFileExtension;
295 
296  //std::unique_ptr ensures that the SecureSvm objects will not get passed by value
297  this->medicalRelevanceSvms.emplace_back(std::unique_ptr<SecureSvm>(new SecureSvm(modelsDirectory, fileNameBuilder.str(), this->clientPublicKey, shared_from_this())));
298  }
299  }
300  }
301 
// Loads the Safety block SVM models. One SecureSvm (held through std::unique_ptr) is created per file name in
// safetySvmModelFiles, in container order, and modelsDirectory and the file name are passed to its constructor.
// Throws std::runtime_error if safetySvmModelFiles is empty.
// NOTE(review): source lines 306 and 308 are missing from this listing, and nothing visible here fills
// safetySvmModelFiles. Presumably line 308 populates it from modelsDirectory (the index lists a
// GetFilesInDirectory(directory) helper returning std::deque<std::string>) - confirm.
305  void Server::loadSafetySvmModels (const std::string &modelsDirectory) {
307 
309  if (safetySvmModelFiles.empty()) {
310  throw std::runtime_error("No model files found in directory: " + modelsDirectory);
311  }
312 
313  for (std::deque<std::string>::const_iterator fileIterator = safetySvmModelFiles.begin(); fileIterator != safetySvmModelFiles.end(); ++fileIterator) {
314  //std::unique_ptr ensures that the SecureSvm objects will not get passed by value
315  this->safetySvms.emplace_back(std::unique_ptr<SecureSvm>(new SecureSvm(modelsDirectory, *fileIterator, this->clientPublicKey, shared_from_this())));
316  }
317 
318  }
319 
// NOTE(review): presumably the body of
//   Server::EncryptedClusterVotes Server::getTotalClusterVotes (Server::EncryptedSvmValues &votes) const
// ("Returns the total number of votes per cluster." in the member index). The signature line is not in this listing.
// For every cluster i it sums, over all other clusters j, the encrypted pairwise outcome of SVM(i, j):
// votes[index] when i < j, and [1] - votes[index] when i > j (only the i < j models are stored).
325  Server::EncryptedClusterVotes clusterVotes;
326 
328  for (unsigned int i = 0; i < this->medicalRelevanceClusterCount; ++i) {
329  //initialize the vote accumulator to [0]
330  clusterVotes.emplace_back(this->encryptedZero);
331 
332  for (unsigned int j = 0; j < this->medicalRelevanceClusterCount; ++j) {
334 
335  //the upper right triangle of the prediction values matrix is stored as an unraveled vector
336  //basically, we need to subtract sum(i - 1) = (i - 1) * i / 2 at each step to determine the index in the vector
337  //also, we subtract (i + 1) at each step because we need to account for the missing SVM(i, i) values
// With n = medicalRelevanceClusterCount and a < b, pair (a, b) is stored at a * (n - 1) - a * (a - 1) / 2 + (b - a - 1).
// This is the same order as the nested "j = i + 1" loops in loadMedicalRelevanceSvmModels. For a == 0 the unsigned
// (a - 1) wraps around, but it is multiplied by a == 0, so the term is still 0.
338  unsigned int index;
339 
340  //we don't have SVM(i, i)
341  if (i != j) {
342  //upper right triangle
343  if (i < j) {
344  index = i * (medicalRelevanceClusterCount - 1) - (i - 1) * i / 2 + j - (i + 1);
345 
346  clusterVotes[i] = clusterVotes[i] + votes[index];
347  }
348  else {
349  index = j * (medicalRelevanceClusterCount - 1) - (j - 1) * j / 2 + i - (j + 1);
350 
351  //prediction(i, j) = 1 - prediction(j, i)
352  clusterVotes[i] = clusterVotes[i] + this->encryptedOne - votes[index];
353  }
354  }
355  }//j
356  }//i
357 
358  return clusterVotes;
359  }
360 
// NOTE(review): presumably the body of
//   void Server::interactiveSignEvaluation (Server::EncryptedSvmValues &data) const
// ("Performs the interactive sign evaluation protocol inplace." in the member index). The signature line is not in this listing.
// Steps:
//   1. Permute data.
//   2. Multiply each value by a fresh random factor.
//   3. Let the client overwrite data via EvaluateSign.
//   4. Undo the permutation.
// No unblinding step follows, so data ends up holding what the client wrote (presumably encrypted sign
// results - confirm against the client's EvaluateSign).
376  SecurePermutation permutation(data.size());
377 
379  permutation.Permute(data);
380 
382  for (Server::EncryptedSvmValues::iterator encryptedSvmValue = data.begin(); encryptedSvmValue != data.end(); ++encryptedSvmValue) {
383  //this operation may cause an overflow on the plaintext data if the random number ends up being too large
// NOTE(review): unlike InteractiveSecureDivision and interactiveMaximumEvaluation, no "+ 1" is added to this random factor.
// If GetRandomInteger can return 0, the blinded value becomes [0] and its sign is lost - confirm.
384  *encryptedSvmValue = *encryptedSvmValue * RandomProvider::GetInstance().GetRandomInteger(this->blindingFactorSize);
385  }
386 
// NOTE(review): client.lock() is dereferenced without checking for an unset or expired client.
388  this->client.lock()->EvaluateSign(data);
389 
391  permutation.InvertPermutation(data);
392  }
393 
// NOTE(review): presumably the body of
//   void Server::interactiveMaximumEvaluation (Server::EncryptedClusterVotes &data) const
// ("Performs the interactive maximum evaluation protocol inplace." in the member index). The signature line is not in this listing.
// Steps:
//   1. Permute data.
//   2. Map every vote v to v * r1 + r2, using the same r1 (>= 1 if GetRandomInteger is non-negative) and the
//      same encrypted r2 for all entries.
//   3. Let the client overwrite data via EvaluateMaximum.
//   4. Undo the permutation.
// Because one affine map is shared by all entries, the ordering of the votes is presumably preserved (barring
// plaintext overflow, see below), so the client can locate the maximum without seeing the raw counts.
409  SecurePermutation permutation(data.size());
410 
412  permutation.Permute(data);
413 
415  BigInteger r1 = RandomProvider::GetInstance().GetRandomInteger(this->blindingFactorSize) + 1;
417  Paillier::Ciphertext r2 = this->cryptoProvider.EncryptInteger(RandomProvider::GetInstance().GetRandomInteger(this->blindingFactorSize));
418 
419  for (Server::EncryptedClusterVotes::iterator encryptedClusterVote = data.begin(); encryptedClusterVote != data.end(); ++encryptedClusterVote) {
420  //this operation may cause an overflow on the plaintext data if the random number ends up being too large
421  *encryptedClusterVote = *encryptedClusterVote * r1 + r2;
422  }
423 
// NOTE(review): client.lock() is dereferenced without checking for an unset or expired client.
425  this->client.lock()->EvaluateMaximum(data);
426 
428  permutation.InvertPermutation(data);
429  }
430 
431 }//namespace SecureRecommendations
432 }//namespace SeComLib
std::vector< Paillier::Ciphertext > EncryptedVector
Define a vector template specialization for vectors of encrypted data.
Definition: secure_svm.h:62
Ciphertext GetEncryptedOne(const bool randomized=true) const
Returns [1].
Processed test data container.
Definition: test_data_row.h:40
Paillier::Ciphertext encryptedOne
Precompute [1] for optimization purposes.
Ciphertext GetEncryptedZero(const bool randomized=true) const
Returns [0].
std::vector< std::unique_ptr< SecureSvm > > safetySvms
Safety Block SVM models.
void DebugValue(const Paillier::Ciphertext &value) const
Sends a request to the client to debug an encrypted value.
std::deque< std::string > safetySvmModelFiles
The safety block model files.
void interactiveMaximumEvaluation(Server::EncryptedClusterVotes &data) const
Performs the interactive maximum evaluation protocol in place.
Server::EncryptedClusterVotes getTotalClusterVotes(Server::EncryptedSvmValues &votes) const
Returns the total number of votes per cluster.
void InteractiveSecureDivision(const BigInteger &numerator, SecureSvm::EncryptedVector &denominators) const
Performs the interactive secure division protocol.
void Initialize()
Initializes the internal SVM instances.
std::vector< BigInteger > dummyContentItems
Dummy vector of content items.
static Config & GetInstance()
Returns a reference to the singleton.
Definition: config.cpp:48
std::string medicalRelevanceModelsDirectory
Medical relevance models folder.
std::deque< std::string > GetSafetyBlockModelFiles() const
Returns the names of the model files for the safety block.
SecureSvm::EncryptedVector x
Encrypted test data vector.
Definition: test_data_row.h:43
std::vector< Paillier::Ciphertext > EncryptedClusterVotes
Container for the encrypted cluster votes.
virtual T_Ciphertext EncryptInteger(const BigInteger &plaintext) const
Encrypt an integer and apply randomization.
SecureSvm::KernelTypes kernel
The SVM kernel type.
PaillierPublicKey clientPublicKey
The client public key.
void loadMedicalRelevanceSvmModels(const std::string &modelsDirectory)
Loads the Medical Relevance block SVM models.
unsigned int contentItemCount
The number of content items.
std::vector< std::unique_ptr< SecureSvm > > medicalRelevanceSvms
Medical Relevance Block SVM models.
unsigned int medicalRelevanceClusterCount
Number of clusters for the Medical Relevance Block.
SecureSvm::EncryptedVector xSquared
Encrypted squared test data vector.
Definition: test_data_row.h:49
Permutation class which implements the Fisher-Yates (Knuth) shuffle algorithm.
static KernelTypes GetKernel(const std::string &input)
Converts the input string to the proper kernel.
Definition: secure_svm.cpp:170
Server(const PaillierPublicKey &key)
Constructor.
void GetPerformancePredictions(SecureSvm::EncryptedVector &firstTwoBlocksPredictions, SecureSvm::EncryptedVector &SafetyBlockPredictions, const TestDataRow &medicalRelevanceTestData, const std::map< std::string, TestDataRow > &safetyTestData) const
Procedure which computes (and returns) the predictions generated by the first two blocks and the safety block.
Secure Support Vector Machine algorithm.
Definition: secure_svm.h:59
std::vector< Paillier::Ciphertext > EncryptedSvmValues
Vector of encrypted SVM evaluations.
void GetAccuracyPredictions(Server::EncryptedClusterVotes &clusterVotes, SecureSvm::EncryptedVector &safetyPredictions, const TestDataRow &medicalRelevanceTestData, const std::map< std::string, TestDataRow > &safetyTestData) const
Procedure which computes (and returns) the cluster votes for the first block and the medical safety predictions.
unsigned int blindingFactorSize
The maximum size of the blinding factors.
std::string modelFileExtension
The model files extension.
T GetParameter(const std::string &parameter) const
Template method which returns the value of the specified configuration parameter. ...
Definition: config.hpp:41
std::vector< Paillier::Ciphertext > dummyEncryptedPreferenceScores
Dummy vector of encrypted preference scores.
std::string safetyModelsDirectory
The full path to the safety block models directory.
SecureSvm::EncryptedVector xx
Encrypted vector product combinations, $[x_i x_j]$, stored as an unraveled upper triangular matrix.
Definition: test_data_row.h:46
void SetClient(const std::shared_ptr< const Hub > &client)
Sets a reference to the client.
void Permute(T_DataType &vector) const
Applies the permutations to the input vector.
void interactiveSignEvaluation(Server::EncryptedSvmValues &data) const
Performs the interactive sign evaluation protocol in place.
Paillier::Ciphertext encryptedZero
Precompute [0] for optimization purposes.
static std::deque< std::string > GetFilesInDirectory(const std::string &directory)
Traverses the provided directory (non-recursively) and extracts the absolute paths to all the files in it.
Definition: filesystem.cpp:40
void loadSafetySvmModels(const std::string &modelsDirectory)
Loads the Safety block SVM models.
Definition of class Server.
The public key container structure for the Paillier cryptosystem.
Definition: paillier.h:49
std::vector< std::string > GetSafetyBlockSvmsUnsafeClasses() const
Returns the unsafe classes for each safety SVM.