add clock, change exceptions, add network exceptions, fix some crashes

This commit is contained in:
Eduardo Bart
2011-12-01 20:25:32 -02:00
parent 4afbe43e6f
commit d5e15d1f06
54 changed files with 442 additions and 274 deletions

View File

@@ -35,7 +35,7 @@ void InputMessage::reset()
uint8 InputMessage::getU8(bool peek)
{
assert(canRead(1));
checkRead(1);
uint8 v = m_buffer[m_readPos];
if(!peek)
@@ -46,7 +46,7 @@ uint8 InputMessage::getU8(bool peek)
uint16 InputMessage::getU16(bool peek)
{
assert(canRead(2));
checkRead(2);
uint16 v = *(uint16_t*)(m_buffer + m_readPos);
if(!peek)
@@ -57,7 +57,7 @@ uint16 InputMessage::getU16(bool peek)
uint32 InputMessage::getU32(bool peek)
{
assert(canRead(4));
checkRead(4);
uint32 v = *(uint32*)(m_buffer + m_readPos);
if(!peek)
@@ -68,7 +68,7 @@ uint32 InputMessage::getU32(bool peek)
uint64 InputMessage::getU64(bool peek)
{
assert(canRead(8));
checkRead(8);
uint64 v = *(uint64*)(m_buffer + m_readPos);
if(!peek)
@@ -80,7 +80,7 @@ uint64 InputMessage::getU64(bool peek)
std::string InputMessage::getString()
{
uint16 stringLength = getU16();
assert(canRead(stringLength));
checkRead(stringLength);
char* v = (char*)(m_buffer + m_readPos);
m_readPos += stringLength;
return std::string(v, stringLength);
@@ -92,3 +92,10 @@ bool InputMessage::canRead(int bytes)
return false;
return true;
}
void InputMessage::checkRead(int bytes)
{
if(!canRead(bytes))
throw NetworkException("InputMessage eof reached");
}

View File

@@ -24,6 +24,7 @@
#define INPUTMESSAGE_H
#include "declarations.h"
#include "networkexception.h"
class InputMessage
{
@@ -56,6 +57,7 @@ public:
private:
bool canRead(int bytes);
void checkRead(int bytes);
uint16 m_readPos;
uint16 m_messageSize;

View File

@@ -0,0 +1,12 @@
#ifndef NETWORKEXCEPTION_H
#define NETWORKEXCEPTION_H
#include "declarations.h"
class NetworkException : public Exception
{
public:
NetworkException(const std::string& what) : Exception(what) { }
};
#endif

View File

@@ -35,7 +35,7 @@ void OutputMessage::reset()
void OutputMessage::addU8(uint8 value)
{
assert(canWrite(1));
checkWrite(1);
m_buffer[m_writePos] = value;
m_writePos += 1;
m_messageSize += 1;
@@ -43,7 +43,7 @@ void OutputMessage::addU8(uint8 value)
void OutputMessage::addU16(uint16 value)
{
assert(canWrite(2));
checkWrite(2);
*(uint16_t*)(m_buffer + m_writePos) = value;
m_writePos += 2;
m_messageSize += 2;
@@ -51,7 +51,7 @@ void OutputMessage::addU16(uint16 value)
void OutputMessage::addU32(uint32 value)
{
assert(canWrite(4));
checkWrite(4);
*(uint32*)(m_buffer + m_writePos) = value;
m_writePos += 4;
m_messageSize += 4;
@@ -59,7 +59,7 @@ void OutputMessage::addU32(uint32 value)
void OutputMessage::addU64(uint64 value)
{
assert(canWrite(8));
checkWrite(8);
*(uint64*)(m_buffer + m_writePos) = value;
m_writePos += 8;
m_messageSize += 8;
@@ -68,7 +68,9 @@ void OutputMessage::addU64(uint64 value)
void OutputMessage::addString(const char* value)
{
size_t stringLength = strlen(value);
assert(stringLength < 0xFFFF && canWrite(stringLength + 2));
if(stringLength > 65535)
throw NetworkException("[OutputMessage::addString] string length > 65535");
checkWrite(stringLength + 2);
addU16(stringLength);
strcpy((char*)(m_buffer + m_writePos), value);
m_writePos += stringLength;
@@ -82,7 +84,9 @@ void OutputMessage::addString(const std::string &value)
void OutputMessage::addPaddingBytes(int bytes, uint8 byte)
{
assert(canWrite(bytes) && bytes >= 0);
if(bytes <= 0)
return;
checkWrite(bytes);
memset((void*)&m_buffer[m_writePos], byte, bytes);
m_writePos += bytes;
m_messageSize += bytes;
@@ -94,3 +98,9 @@ bool OutputMessage::canWrite(int bytes)
return false;
return true;
}
void OutputMessage::checkWrite(int bytes)
{
if(!canWrite(bytes))
throw NetworkException("OutputMessage max buffer size reached");
}

View File

@@ -24,6 +24,7 @@
#define OUTPUTMESSAGE_H
#include "declarations.h"
#include "networkexception.h"
class OutputMessage
{
@@ -56,6 +57,7 @@ public:
private:
bool canWrite(int bytes);
void checkWrite(int bytes);
uint16 m_writePos;
uint16 m_messageSize;

View File

@@ -107,14 +107,15 @@ void Protocol::internalRecvData(uint8* buffer, uint16 size)
if(m_checksumEnabled) {
uint32 checksum = getAdlerChecksum(m_inputMessage.getBuffer() + InputMessage::DATA_POS, m_inputMessage.getMessageSize() - InputMessage::CHECKSUM_LENGTH);
if(m_inputMessage.getU32() != checksum) {
// error
logError("got a network message with invalid checksum");
logTraceError("got a network message with invalid checksum");
return;
}
}
if(m_xteaEncryptionEnabled)
xteaDecrypt(m_inputMessage);
if(m_xteaEncryptionEnabled) {
if(!xteaDecrypt(m_inputMessage))
return;
}
onRecv(m_inputMessage);
}
@@ -133,7 +134,7 @@ bool Protocol::xteaDecrypt(InputMessage& inputMessage)
{
uint16 messageSize = inputMessage.getMessageSize() - InputMessage::CHECKSUM_LENGTH;
if(messageSize % 8 != 0) {
logError("invalid encrypted network message");
logTraceError("invalid encrypted network message");
return false;
}
@@ -156,7 +157,7 @@ bool Protocol::xteaDecrypt(InputMessage& inputMessage)
int tmp = inputMessage.getU16();
if(tmp > inputMessage.getMessageSize() - 4) {
logError("invalid decrypted a network message");
logTraceError("invalid decrypted a network message");
return false;
}

View File

@@ -65,9 +65,11 @@ void Rsa::setKey(const char* p, const char* q, const char* d)
mpz_clear(pm1);
mpz_clear(qm1);
m_keySet = true;
}
bool Rsa::encrypt(char* msg, int32_t size, const char* key)
void Rsa::encrypt(char* msg, int32_t size, const char* key)
{
mpz_t plain, c;
mpz_init2(plain, 1024);
@@ -92,11 +94,13 @@ bool Rsa::encrypt(char* msg, int32_t size, const char* key)
mpz_clear(plain);
mpz_clear(e);
mpz_clear(mod);
return true;
}
bool Rsa::decrypt(char* msg, int32_t size)
{
if(!m_keySet)
return false;
mpz_t c,v1,v2,u2,tmp;
mpz_init2(c, 1024);
mpz_init2(v1, 1024);
@@ -130,6 +134,5 @@ bool Rsa::decrypt(char* msg, int32_t size)
mpz_clear(v2);
mpz_clear(u2);
mpz_clear(tmp);
return true;
}

View File

@@ -34,7 +34,7 @@ public:
void setKey(const char* p, const char* q, const char* d);
bool decrypt(char* msg, int32_t size);
static bool encrypt(char* msg, int32_t size, const char* key);
static void encrypt(char* msg, int32_t size, const char* key);
protected:
bool m_keySet;