package org.adbcj.postgresql.codec.backend;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.io.IOException;
import org.adbcj.Type;
import org.adbcj.Value;
import org.adbcj.postgresql.codec.PgField;
import org.adbcj.postgresql.codec.FormatCode;
import org.adbcj.postgresql.codec.ErrorField;
import org.adbcj.postgresql.codec.PgFieldType;
import org.adbcj.postgresql.codec.ConfigurationVariable;
import org.adbcj.postgresql.codec.ConnectionState;
import org.adbcj.support.DefaultValue;
import org.adbcj.support.DecoderInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
// Logger used for the trace/debug/warn output emitted while decoding.
private final Logger logger = LoggerFactory.getLogger(BackendMessageDecoder.class);
// Minimum number of bytes required before decoding is attempted in non-blocking mode:
// doDecode reads one type byte followed by a four-byte length.
private static final int MESSAGE_MIN_SIZE = 5;
// Size in bytes of the length field. It is subtracted from the declared message length to
// obtain the body length, and is also used as the size of the MD5 salt buffer.
private static final int FIELD_LENGTH_SIZE = 4;
// Connection state consulted while decoding (backend charset, current result set fields,
// database name) and updated when a row description is decoded.
private final ConnectionState connectionState;
this.connectionState = connectionState;
}
/**
 * Decodes the next backend message from {@code input}.
 *
 * <p>When the stream supports marking, its position is remembered before decoding starts
 * and restored if no message is produced (either because {@code doDecode} returned
 * {@code null} or because it threw), so the same bytes can be decoded again later.
 *
 * @param input stream to read the message from
 * @param block when {@code false}, decoding gives up and returns {@code null} if fewer
 *              bytes than required are reported as available; when {@code true}, the
 *              availability checks are skipped and the message is read directly
 * @return the decoded message, or {@code null} if not enough data was available
 * @throws IllegalArgumentException if {@code block} is {@code false} and {@code input}
 *                                  does not support marking
 * @throws IOException if reading from {@code input} fails
 */
public AbstractBackendMessage decode(DecoderInputStream input, boolean block) throws IOException {
    final boolean markable = input.markSupported();
    if (!markable && !block) {
        throw new IllegalArgumentException("Non-blocking decoding requires an InputStream that supports marking");
    }
    if (markable) {
        input.mark(Integer.MAX_VALUE);
    }
    AbstractBackendMessage message = null;
    try {
        message = doDecode(input, block);
    } finally {
        // Rewind only when a mark was actually set. On a stream without mark support,
        // reset() is expected to fail (java.io.InputStream's default throws IOException),
        // and throwing from this finally block would replace the exception raised by
        // doDecode with an unrelated "reset not supported" error.
        if (message == null && markable) {
            input.reset();
        }
    }
    return message;
}
/**
 * Reads one message header (type byte and four-byte length) and dispatches to the decoder
 * for that message type.
 *
 * @param input stream positioned at the start of a message
 * @param block when {@code false}, returns {@code null} instead of reading if the stream
 *              reports fewer available bytes than the header or the message body needs
 * @return the decoded message, or {@code null} if more data is needed (non-blocking only)
 * @throws IllegalStateException if the declared length is smaller than the length field
 *                               itself, the message type is unknown, or no decoder exists
 *                               for the message type
 * @throws IOException if reading from {@code input} fails
 */
private AbstractBackendMessage doDecode(DecoderInputStream input, boolean block) throws IOException {
    logger.trace("Decoding message");

    if (!block && input.available() < MESSAGE_MIN_SIZE) {
        return null;
    }

    byte typeValue = input.readByte();

    // The declared length counts the length field itself, so anything smaller than that is
    // malformed. Checking before subtracting also avoids integer underflow turning a huge
    // negative value into a bogus positive body length.
    int declaredLength = input.readInt();
    if (declaredLength < FIELD_LENGTH_SIZE) {
        throw new IllegalStateException(
                "Invalid length " + declaredLength + " for message of type " + typeValue);
    }
    int length = declaredLength - FIELD_LENGTH_SIZE;

    if (!block && input.available() < length) {
        logger.trace("Need more data");
        return null;
    }

    // Restrict subsequent reads to the body of this message.
    input.setLimit(length);

    BackendMessageType type = BackendMessageType.fromValue(typeValue);
    if (type == null) {
        throw new IllegalStateException("Do not recognize message of type " + typeValue);
    }
    logger.debug("Decoding message of type {}", type);

    switch (type) {
    // Messages that carry no payload to decode
    case BIND_COMPLETE:
    case CLOSE_COMPLETE:
    case COPY_DONE:
    case EMPTY_QUERY_RESPONSE:
    case NO_DATA:
    case PARSE_COMPLETE:
    case PORTAL_SUSPENDED:
        return new SimpleBackendMessage(type);

    // Messages with a dedicated decoder
    case AUTHENTICATION:
        return decodeAuthentication(input);
    case COMMAND_COMPLETE:
        return decodeCommandComplete(input);
    case DATA_ROW:
        return decodeDataRow(input);
    case ERROR_RESPONSE:
        return decodeError(input);
    case KEY:
        return decodeKey(input);
    case PARAMETER_STATUS:
        return decodeParameterStatus(input);
    case READY_FOR_QUERY:
        return decodeReadyForQuery(input);
    case ROW_DESCRIPTION:
        return decodeRowDescription(input);

    // Known message types that this decoder does not handle
    case COPY_DATA:
    case COPY_IN_RESPONSE:
    case COPY_OUT_RESPONSE:
    case FUNCTION_CALL_RESPONSE:
    case NOTICE_RESPONSE:
    case NOTIFICATION_RESPONSE:
    case PARAMETER_DESCRIPTION:
    case PASSWORD:
        throw new IllegalStateException("No decoder implemented for message of type " + type);

    default:
        throw new IllegalStateException(String.format("Messages of type %s are not implemented", typeValue));
    }
}
// The first int of the payload selects the authentication type by its position in the enum.
// NOTE(review): a code outside the range of AuthenticationType.values() throws
// ArrayIndexOutOfBoundsException here instead of reaching the default branch below.
AuthenticationType authenticationType = AuthenticationType.values()[input.readInt()];
switch(authenticationType) {
// Types for which no further payload is read
case OK:
case KERBEROS_5:
case CLEARTEXT_PASSWORD:
case SCM_CREDENTIAL:
case GSS:
return new AuthenticationMessage(authenticationType);
// Types followed by a fixed-size salt
// NOTE(review): the return value of input.read(...) is ignored in the branches below, so a
// short read would go unnoticed — confirm DecoderInputStream.read fills the whole array.
case CRYPT_PASSWORD:
byte[] cryptSalt = new byte[2];
input.read(cryptSalt);
return new AuthenticationMessage(authenticationType, cryptSalt);
case MD5_PASSWORD:
byte[] md5Salt = new byte[FIELD_LENGTH_SIZE];
input.read(md5Salt);
return new AuthenticationMessage(authenticationType, md5Salt);
case GSS_CONTINUE:
// The buffer is sized from input.getLimit(); presumably that is the number of bytes
// remaining in the current message — TODO confirm against DecoderInputStream.
byte[] data = new byte[input.getLimit()];
input.read(data);
return new AuthenticationMessage(authenticationType, data);
case UNKNOWN:
default:
throw new IllegalStateException("Don't know how to handle authentication type of " + authenticationType);
}
}
// Matches a command completion string: one word followed by up to two optional numbers,
// each separated by optional whitespace.
private static final Pattern COMMAND_PATTERN = Pattern.compile("(\\w+)\\s*(\\d*)\\s*(\\d*)");
Charset charset = connectionState.getBackendCharset();
String commandStr = input.readString(charset);
Matcher matcher = COMMAND_PATTERN.matcher(commandStr);
if (!matcher.matches()) {
throw new IllegalStateException(String.format("Unable to parse command completion string '%s'", commandStr));
}
// NOTE(review): presumably Command is an enum, in which case valueOf throws
// IllegalArgumentException for a word that is not one of its constants — confirm.
Command command = Command.valueOf(matcher.group(1));
// -1 is passed on when the corresponding number is absent from the string.
long count = -1;
int oid = -1;
// With two numbers, the first is the oid and the second the count; with a single number,
// it is the count.
if (matcher.group(3).length() > 0) {
oid = Integer.valueOf(matcher.group(2));
count = Long.valueOf(matcher.group(3));
} else if (matcher.group(2).length() > 0) {
count = Long.valueOf(matcher.group(2));
}
return new CommandCompleteMessage(command, count, oid);
}
/**
 * Decodes a data row into one {@link Value} per column, using the field definitions
 * currently stored in the connection state.
 *
 * <p>Each column is prefixed by a 32-bit length; a negative length produces a value
 * wrapping {@code null}. Only INTEGER, BIGINT and VARCHAR columns are decoded.
 *
 * @param input stream positioned at the start of the row payload
 * @return the decoded row
 * @throws IllegalStateException if no field definitions are available, the row contains
 *                               more columns than there are field definitions, or a column
 *                               type or format code is not supported
 * @throws NumberFormatException if a text-format numeric column cannot be parsed
 * @throws IOException if reading from {@code input} fails
 */
private DataRowMessage decodeDataRow(DecoderInputStream input) throws IOException {
    Charset charset = connectionState.getBackendCharset();
    PgField[] fields = connectionState.getCurrentResultSetFields();
    if (fields == null) {
        throw new IllegalStateException("Received a data row without any field definitions in the request payload");
    }

    int fieldCount = input.readUnsignedShort();
    // Fail with a descriptive error up front instead of an ArrayIndexOutOfBoundsException
    // part-way through the row.
    if (fieldCount > fields.length) {
        throw new IllegalStateException(
                "Received a data row with " + fieldCount + " columns but only " + fields.length
                + " field definitions are available");
    }

    Value[] values = new Value[fieldCount];
    for (int i = 0; i < fieldCount; i++) {
        int valueLength = input.readInt();
        PgField field = fields[i];
        Value value;
        if (valueLength < 0) {
            // A negative length means there are no value bytes for this column.
            value = new DefaultValue(field, null);
        } else {
            String strVal;
            switch (field.getColumnType()) {
            case INTEGER:
                switch (field.getFormatCode()) {
                case BINARY:
                    value = new DefaultValue(field, input.readInt());
                    break;
                case TEXT:
                    strVal = input.readString(valueLength, charset);
                    value = new DefaultValue(field, Integer.valueOf(strVal));
                    break;
                default:
                    throw new IllegalStateException("Unable to decode format of " + field.getFormatCode());
                }
                break;
            case BIGINT:
                switch (field.getFormatCode()) {
                case BINARY: {
                    // High word first, then low word. The low word is masked because an
                    // int with bit 31 set is sign-extended when widened to long, which
                    // would otherwise overwrite the high word with all ones.
                    long high = input.readInt();
                    long low = input.readInt() & 0xFFFFFFFFL;
                    value = new DefaultValue(field, (high << 32) | low);
                    break;
                }
                case TEXT:
                    strVal = input.readString(valueLength, charset);
                    value = new DefaultValue(field, Long.valueOf(strVal));
                    break;
                default:
                    throw new IllegalStateException("Unable to decode format of " + field.getFormatCode());
                }
                break;
            case VARCHAR:
                strVal = input.readString(valueLength, charset);
                value = new DefaultValue(field, strVal);
                break;
            default:
                // Consume the value's bytes before reporting the unsupported type.
                input.skip(valueLength);
                throw new IllegalStateException("Unable to decode column of type " + field.getColumnType());
            }
        }
        values[i] = value;
    }
    return new DataRowMessage(values);
}
/**
 * Decodes an error response: a sequence of (one-byte field code, string) pairs terminated
 * by a zero byte.
 *
 * <p>Pairs whose code does not map to an {@link ErrorField} are logged and dropped.
 *
 * @param input stream positioned at the start of the error payload
 * @return a message holding the recognized fields and their values
 * @throws IOException if reading from {@code input} fails
 */
private AbstractBackendMessage decodeError(DecoderInputStream input) throws IOException {
    Map<ErrorField, String> errorFields = new HashMap<ErrorField, String>();

    byte code = input.readByte();
    while (code != 0) {
        ErrorField errorField = ErrorField.toErrorField(code);
        String text = input.readString(connectionState.getBackendCharset());
        if (errorField != null) {
            errorFields.put(errorField, text);
        } else {
            logger.warn("Unrecognized error field of type '{}' with the value '{}'", (char)code, text);
        }
        code = input.readByte();
    }

    return new ErrorResponseMessage(errorFields);
}
/**
 * Decodes a key message consisting of two consecutive 32-bit integers: the pid followed by
 * the key.
 *
 * @param input stream positioned at the start of the key payload
 * @return a {@code KeyMessage} built from the two values, in the order they were read
 * @throws IOException if reading from {@code input} fails
 */
private AbstractBackendMessage decodeKey(DecoderInputStream input) throws IOException {
    final int processId = input.readInt();
    final int keyValue = input.readInt();
    return new KeyMessage(processId, keyValue);
}
Charset charset = connectionState.getBackendCharset();
// The payload is two strings: the parameter name followed by its value.
String name = input.readString(charset);
String value = input.readString(charset);
ConfigurationVariable cv = ConfigurationVariable.fromName(name);
// A name without a matching ConfigurationVariable is only logged; the message is still
// created, with a null variable.
if (cv == null) {
logger.warn("No ConfigurationVariable entry for {}", name);
}
return new ParameterMessage(cv, value);
}
// The payload is a single status byte, interpreted as a character.
char s = (char)input.readByte();
Status status;
// 'E' -> ERROR, 'I' -> IDLE, 'T' -> TRANSACTION; any other character is rejected.
switch(s) {
case 'E':
status = Status.ERROR;
break;
case 'I':
status = Status.IDLE;
break;
case 'T':
status = Status.TRANSACTION;
break;
default:
throw new IllegalStateException("Unrecognized server status " + s);
}
return new ReadyMessage(status);
}
Charset charset = connectionState.getBackendCharset();
// Number of field descriptors that follow.
int fieldCount = input.readUnsignedShort();
PgField[] fields = new PgField[fieldCount];
for (int i = 0; i < fieldCount; i++) {
// Each descriptor is read in this fixed order: name, table oid, column attribute number,
// type oid, type size, type modifier, format code.
String name = input.readString(charset);
int tableOid = input.readInt();
int columnAttributeNumber = input.readUnsignedShort();
int typeOid = input.readInt();
short typeSize = input.readShort();
int typeModifier = input.readInt();
// NOTE(review): the format code is used as an index into FormatCode.values(); a value
// outside that range (including a negative short) throws ArrayIndexOutOfBoundsException.
FormatCode code = FormatCode.values()[input.readShort()];
Type type;
// Translate the type oid into a Type; oids not listed here are rejected.
switch (typeOid) {
case PgFieldType.BOOLEAN:
type = Type.BOOLEAN;
break;
case PgFieldType.BIGINT:
type = Type.BIGINT;
break;
case PgFieldType.CHAR:
type = Type.CHAR;
break;
case PgFieldType.DATE:
type = Type.DATE;
break;
case PgFieldType.DOUBLE:
type = Type.DOUBLE;
break;
case PgFieldType.INTEGER:
type = Type.INTEGER;
break;
case PgFieldType.REAL:
type = Type.REAL;
break;
case PgFieldType.SMALLINT:
type = Type.SMALLINT;
break;
case PgFieldType.VARCHAR:
type = Type.VARCHAR;
break;
default:
throw new IllegalStateException("Unable to handle field type with oid " + typeOid);
}
fields[i] = new PgField(
i,
connectionState.getDatabaseName(),
type,
name,
tableOid,
columnAttributeNumber,
code,
typeSize,
typeModifier
);
// NOTE(review): the two statements below are inside the loop, so they run once per field
// with a partially filled array and never run when fieldCount is 0 (leaving any previously
// stored fields in place). Presumably they were meant to run once after the loop — confirm.
logger.debug("Setting fields for current result set: {}", fields);
connectionState.setCurrentResultSetFields(fields);
}
return new RowDescriptionMessage(fields);
}
}