Skip to content

Commit

Permalink
Merge pull request #2 from Anvilcraft/update-zig
Browse files Browse the repository at this point in the history
port to latest zig
  • Loading branch information
star-tek-mb authored Dec 28, 2023
2 parents fde6e11 + 128b5a7 commit 378097d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 41 deletions.
14 changes: 7 additions & 7 deletions src/auth.zig
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub const Scram = struct {
pub fn writeTo(self: *State, wb: *WriteBuffer) void {
switch (self.*) {
.update => |u| {
var len = "n,,n=,r=".len + u.nonce.len;
const len = "n,,n=,r=".len + u.nonce.len;
wb.writeString("SCRAM-SHA-256");
wb.writeInt(u32, @as(u32, @intCast(len)));
wb.writeBytes("n,,n=,r=");
Expand Down Expand Up @@ -113,10 +113,10 @@ pub const Scram = struct {
}

var decoded_salt_buf: [32]u8 = undefined;
var decoded_salt_len = try Base64.Decoder.calcSizeForSlice(salt);
const decoded_salt_len = try Base64.Decoder.calcSizeForSlice(salt);
if (decoded_salt_len > 32) return error.OutOfMemory;
try Base64.Decoder.decode(&decoded_salt_buf, salt);
var decoded_salt = decoded_salt_buf[0..decoded_salt_len];
const decoded_salt = decoded_salt_buf[0..decoded_salt_len];

var salted_password = hi(self.state.update.password, decoded_salt, try std.fmt.parseInt(usize, iterations, 10));
var hmac = Hmac.init(&salted_password);
Expand Down Expand Up @@ -168,12 +168,12 @@ pub const Scram = struct {
if (std.meta.activeTag(self.state) != .finish) return error.InvalidState;
if (message[0] != 'v' and message.len <= 2) return error.InvalidInput;

var verifier = message[2..];
const verifier = message[2..];
var verifier_buf: [128]u8 = undefined;
var verifier_len = try Base64.Decoder.calcSizeForSlice(verifier);
const verifier_len = try Base64.Decoder.calcSizeForSlice(verifier);
if (verifier_len > 128) return error.OutOfMemory;
try Base64.Decoder.decode(&verifier_buf, verifier);
var decoded_verified = verifier_buf[0..verifier_len];
const decoded_verified = verifier_buf[0..verifier_len];

var hmac = Hmac.init(&self.state.finish.salted_password);
hmac.update("Server Key");
Expand Down Expand Up @@ -227,7 +227,7 @@ test "scram-sha-256" {
defer wb.deinit();

var scram = Scram.init(password);
std.mem.copy(u8, scram.state.update.nonce[0..], nonce[0..]);
@memcpy(scram.state.update.nonce[0..nonce.len], nonce);
scram.state.writeTo(&wb);
try std.testing.expectEqualStrings(client_first, wb.buf.items[23..]);

Expand Down
6 changes: 3 additions & 3 deletions src/encdec.zig
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,17 @@ pub fn quoteLiteral(allocator: std.mem.Allocator, literal: []const u8) ![]const
}

test "quote identifier" {
var id = try quoteIdentifier(std.testing.allocator, "my_table");
const id = try quoteIdentifier(std.testing.allocator, "my_table");
defer std.testing.allocator.free(id);
try std.testing.expectEqualStrings("\"my_table\"", id);
}

test "quote literal" {
var q1 = try quoteLiteral(std.testing.allocator, "hello '' world");
const q1 = try quoteLiteral(std.testing.allocator, "hello '' world");
defer std.testing.allocator.free(q1);
try std.testing.expectEqualStrings("'hello '''' world'", q1);

var q2 = try quoteLiteral(std.testing.allocator, "hello \\'\\' world");
const q2 = try quoteLiteral(std.testing.allocator, "hello \\'\\' world");
defer std.testing.allocator.free(q2);
try std.testing.expectEqualStrings(" E'hello \\\\''\\\\'' world'", q2);
}
18 changes: 9 additions & 9 deletions src/messaging.zig
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@ pub const ReadBuffer = struct {
}

pub fn readInt(self: *ReadBuffer, comptime T: type) T {
var ret = std.mem.readIntBig(T, self.buf[self.pos..][0..@sizeOf(T)]);
const ret = std.mem.readInt(T, self.buf[self.pos..][0..@sizeOf(T)], .big);
self.pos += @sizeOf(T);
return ret;
}

pub fn readString(self: *ReadBuffer) []const u8 {
var start = self.pos;
const start = self.pos;
while (self.buf[self.pos] != 0 and self.pos < self.buf.len) : (self.pos += 1) {}
self.pos += 1;
return self.buf[start .. self.pos - 1];
}

pub fn readBytes(self: *ReadBuffer, num: u32) []const u8 {
var ret = self.buf[self.pos .. self.pos + num];
const ret = self.buf[self.pos .. self.pos + num];
self.pos += num;
return ret;
}
Expand Down Expand Up @@ -51,7 +51,7 @@ pub const WriteBuffer = struct {
}

pub fn writeInt(self: *WriteBuffer, comptime T: type, value: T) void {
self.buf.writer().writeIntBig(T, value) catch {};
self.buf.writer().writeInt(T, value, .big) catch {};
}

pub fn writeString(self: *WriteBuffer, string: []const u8) void {
Expand All @@ -65,9 +65,9 @@ pub const WriteBuffer = struct {

pub fn finalize(self: *WriteBuffer) void {
if (self.tag == null) {
std.mem.writeIntBig(u32, self.buf.items[self.index..][0..4], @as(u32, @intCast(self.buf.items.len - self.index)));
std.mem.writeInt(u32, self.buf.items[self.index..][0..4], @as(u32, @intCast(self.buf.items.len - self.index)), .big);
} else {
std.mem.writeIntBig(u32, self.buf.items[self.index + 1 ..][0..4], @as(u32, @intCast(self.buf.items.len - self.index - 1)));
std.mem.writeInt(u32, self.buf.items[self.index + 1 ..][0..4], @as(u32, @intCast(self.buf.items.len - self.index - 1)), .big);
}
}

Expand Down Expand Up @@ -106,10 +106,10 @@ pub const Message = struct {
pub fn read(allocator: std.mem.Allocator, reader: anytype) !Message {
var type_and_len: [5]u8 = undefined;
_ = try reader.read(&type_and_len);
var @"type" = type_and_len[0];
var len = std.mem.readIntBig(u32, type_and_len[1..][0..4]);
const @"type" = type_and_len[0];
const len = std.mem.readInt(u32, type_and_len[1..][0..4], .big);
if (len > 4) {
var msg = try allocator.alloc(u8, len - 4);
const msg = try allocator.alloc(u8, len - 4);
_ = try reader.read(msg);
return Message{ .type = @"type", .len = len, .msg = msg };
}
Expand Down
46 changes: 24 additions & 22 deletions src/pgz.zig
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ pub const Connection = struct {
/// caller owns memory, release memory with `statement.deinit()`
pub fn prepare(self: *Connection, sql: []const u8) !Statement {
var name_buffer: [10]u8 = undefined; // 4294967295 - max value - length 10
var name = try std.fmt.bufPrint(&name_buffer, "{d}", .{self.statement_count});
const name = try std.fmt.bufPrint(&name_buffer, "{d}", .{self.statement_count});
self.statement_count += 1;

var wb = try WriteBuffer.init(self.allocator, 'P');
Expand Down Expand Up @@ -191,11 +191,11 @@ pub const Connection = struct {
if (msg.type != 'R') return error.UnexpectedMessage;

var buffer = ReadBuffer.init(msg.msg);
var password_type = buffer.readInt(u32);
const password_type = buffer.readInt(u32);
switch (password_type) {
0 => {},
5 => {
var salt = buffer.readBytes(4);
const salt = buffer.readBytes(4);

var md5 = auth.md5(user, password, salt);
var wb = try WriteBuffer.init(self.allocator, 'p');
Expand All @@ -206,7 +206,7 @@ pub const Connection = struct {
var check_msg = try Message.read(self.allocator, self.stream.reader());
defer check_msg.free(self.allocator);
var check_buffer = ReadBuffer.init(check_msg.msg);
var status = check_buffer.readInt(u32);
const status = check_buffer.readInt(u32);
if (status != 0) return error.AuthenticationError;
},
10 => {
Expand Down Expand Up @@ -266,16 +266,16 @@ pub const Connection = struct {
'Z' => break,
'T' => {
var buffer = ReadBuffer.init(msg.msg);
var num_rows = buffer.readInt(u16);
const num_rows = buffer.readInt(u16);
try row_headers.ensureTotalCapacity(self.allocator, num_rows);
for (0..num_rows) |_| {
var name = try self.allocator.dupe(u8, buffer.readString());
const name = try self.allocator.dupe(u8, buffer.readString());
_ = buffer.readInt(u32);
_ = buffer.readInt(u16);
var data_type = buffer.readInt(u32);
const data_type = buffer.readInt(u32);
_ = buffer.readInt(u16);
_ = buffer.readInt(u32);
var text_or_binary = buffer.readInt(u16);
const text_or_binary = buffer.readInt(u16);
try row_headers.append(self.allocator, RowHeader{
.name = name,
.type = data_type,
Expand All @@ -285,11 +285,11 @@ pub const Connection = struct {
},
'D' => {
var buffer = ReadBuffer.init(msg.msg);
var num_rows = buffer.readInt(u16);
const num_rows = buffer.readInt(u16);
var row: T = undefined;

for (0..num_rows) |i| {
var len = buffer.readInt(u32);
const len = buffer.readInt(u32);
var value: ?[]const u8 = undefined;
if (len == @as(u32, @truncate(-1))) {
value = null;
Expand Down Expand Up @@ -332,18 +332,20 @@ pub const Connection = struct {
while (code != 0) : (code = rb.readInt(u8)) {
switch (code) {
'S', 'V' => {
std.mem.copy(u8, self.last_error.?.severity[0..], rb.readString());
const s = rb.readString();
@memcpy(self.last_error.?.severity[0..s.len], s);
},
'C' => {
std.mem.copy(u8, self.last_error.?.code[0..], rb.readString());
const s = rb.readString();
@memcpy(self.last_error.?.code[0..s.len], s);
},
'M' => {
var message = rb.readString();
const message = rb.readString();
if (message.len > 256) {
self.last_error.?.length = 0;
} else {
self.last_error.?.length = @as(u32, @intCast(message.len));
std.mem.copy(u8, self.last_error.?.message[0..], message);
@memcpy(self.last_error.?.message[0..message.len], message);
}
},
else => {
Expand All @@ -368,7 +370,7 @@ pub const Statement = struct {
/// deinitializes and frees allocated memory
pub fn deinit(self: *Statement) void {
var name_buffer: [10]u8 = undefined; // 4294967295 - max value - length 10
var name = std.fmt.bufPrint(&name_buffer, "{d}", .{self.statement}) catch return;
const name = std.fmt.bufPrint(&name_buffer, "{d}", .{self.statement}) catch return;
var buffer = WriteBuffer.init(self.connection.allocator, 'C') catch return;
defer buffer.deinit();
buffer.writeInt(u8, 'C');
Expand Down Expand Up @@ -404,7 +406,7 @@ pub const Statement = struct {

fn sendExec(self: *Statement, args: anytype) !void {
var name_buffer: [10]u8 = undefined; // 4294967295 - max value - length 10
var name = try std.fmt.bufPrint(&name_buffer, "{d}", .{self.statement});
const name = try std.fmt.bufPrint(&name_buffer, "{d}", .{self.statement});

var wb = try WriteBuffer.init(self.connection.allocator, 'B');
defer wb.deinit();
Expand All @@ -416,7 +418,7 @@ pub const Statement = struct {
if ((@typeInfo(field.type) == .Optional or @typeInfo(field.type) == .Null) and @field(args, field.name) == null) {
wb.writeInt(u32, @as(u32, @truncate(-1)));
} else {
var encoded = try encdec.encode(self.connection.allocator, @field(args, field.name));
const encoded = try encdec.encode(self.connection.allocator, @field(args, field.name));
defer self.connection.allocator.free(encoded);
wb.writeInt(u32, @as(u32, @intCast(encoded.len)));
wb.writeBytes(encoded);
Expand All @@ -436,8 +438,8 @@ fn parseAffectedRows(command: []const u8) u32 {

var tokenizer = std.mem.tokenize(u8, command, " ");
_ = tokenizer.next(); // INSERT or SELECT
var second = tokenizer.next(); // 0 or affected rows
var maybe_last = tokenizer.next(); // affected rows or EOF
const second = tokenizer.next(); // 0 or affected rows
const maybe_last = tokenizer.next(); // affected rows or EOF
if (maybe_last) |last| {
return std.fmt.parseInt(u32, last, 10) catch 0;
} else {
Expand All @@ -451,7 +453,7 @@ test "connect" {
}

test "wrong auth" {
var res = Connection.init(std.testing.allocator, try std.Uri.parse("postgres://testing:wrong@localhost:5432/testing"));
const res = Connection.init(std.testing.allocator, try std.Uri.parse("postgres://testing:wrong@localhost:5432/testing"));
try std.testing.expectError(error.AuthenticationError, res);
}

Expand Down Expand Up @@ -527,8 +529,8 @@ test "encoding decoding null" {
defer conn.deinit();
var stmt = try conn.prepare("SELECT $1, $2;");
defer stmt.deinit();
var a: ?u32 = null;
var b: ?[]const u8 = "hi";
const a: ?u32 = null;
const b: ?[]const u8 = "hi";
var result = try stmt.query(struct { ?u8, ?[]const u8 }, .{ a, b });
defer result.deinit();
try std.testing.expectEqual(@as(usize, 1), result.data.len);
Expand Down

0 comments on commit 378097d

Please sign in to comment.