Skip to content

Commit

Permalink
Reflect PR #42 KeyValuePairEnumerableObject
Browse files Browse the repository at this point in the history
  • Loading branch information
C-SELLERS committed Feb 4, 2021
1 parent 3a6fdf5 commit de07b9a
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/runtime/classmanager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@ private static ClassBase CreateClass(Type type)
impl = new ArrayObject(type);
}

else if (type.IsKeyValuePairEnumerable())
{
impl = new KeyValuePairEnumerableObject(type);
}

else if (type.IsInterface)
{
impl = new InterfaceObject(type);
Expand Down
112 changes: 112 additions & 0 deletions src/runtime/keyvaluepairenumerableobject.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
using System;
using System.Collections.Generic;
using System.Reflection;

namespace Python.Runtime
{
/// <summary>
/// Implements a Python type for managed KeyValuePairEnumerable (dictionaries).
/// This type is essentially the same as a ClassObject, except that it provides
/// sequence semantics to support natural dictionary usage (__contains__ and __len__)
/// from Python.
/// </summary>
internal class KeyValuePairEnumerableObject : ClassObject
{
private static Dictionary<Tuple<Type, string>, MethodInfo> methodsByType = new Dictionary<Tuple<Type, string>, MethodInfo>();
private static List<string> requiredMethods = new List<string> { "Count", "ContainsKey" };

internal static bool VerifyMethodRequirements(Type type)
{
foreach (var requiredMethod in requiredMethods)
{
var method = type.GetMethod(requiredMethod);
if (method == null)
{
method = type.GetMethod($"get_{requiredMethod}");
if (method == null)
{
return false;
}
}

var key = Tuple.Create(type, requiredMethod);
methodsByType.Add(key, method);
}

return true;
}

internal KeyValuePairEnumerableObject(Type tp) : base(tp)
{

}

internal override bool CanSubclass() => false;

/// <summary>
/// Implements __len__ for dictionary types.
/// </summary>
public static int mp_length(IntPtr ob)
{
var obj = (CLRObject)GetManagedObject(ob);
var self = obj.inst;

var key = Tuple.Create(self.GetType(), "Count");
var methodInfo = methodsByType[key];

return (int)methodInfo.Invoke(self, null);
}

/// <summary>
/// Implements __contains__ for dictionary types.
/// </summary>
public static int sq_contains(IntPtr ob, IntPtr v)
{
var obj = (CLRObject)GetManagedObject(ob);
var self = obj.inst;

var key = Tuple.Create(self.GetType(), "ContainsKey");
var methodInfo = methodsByType[key];

var parameters = methodInfo.GetParameters();
object arg;
if (!Converter.ToManaged(v, parameters[0].ParameterType, out arg, false))
{
Exceptions.SetError(Exceptions.TypeError,
$"invalid parameter type for sq_contains: should be {Converter.GetTypeByAlias(v)}, found {parameters[0].ParameterType}");
}

return (bool)methodInfo.Invoke(self, new[] { arg }) ? 1 : 0;
}
}

public static class KeyValuePairEnumerableObjectExtension
{
public static bool IsKeyValuePairEnumerable(this Type type)
{
var iEnumerableType = typeof(IEnumerable<>);
var keyValuePairType = typeof(KeyValuePair<,>);

var interfaces = type.GetInterfaces();
foreach (var i in interfaces)
{
if (i.IsGenericType &&
i.GetGenericTypeDefinition() == iEnumerableType)
{
var arguments = i.GetGenericArguments();
if (arguments.Length != 1) continue;

var a = arguments[0];
if (a.IsGenericType &&
a.GetGenericTypeDefinition() == keyValuePairType &&
a.GetGenericArguments().Length == 2)
{
return KeyValuePairEnumerableObject.VerifyMethodRequirements(type);
}
}
}

return false;
}
}
}
106 changes: 106 additions & 0 deletions src/testing/dictionarytest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;

namespace Python.Test
{
/// <summary>
/// Supports units tests for dictionary __contains__ and __len__
/// </summary>
public class PublicDictionaryTest
{
public IDictionary<string, int> items;

public PublicDictionaryTest()
{
items = new int[5] { 0, 1, 2, 3, 4 }
.ToDictionary(k => k.ToString(), v => v);
}
}


public class ProtectedDictionaryTest
{
protected IDictionary<string, int> items;

public ProtectedDictionaryTest()
{
items = new int[5] { 0, 1, 2, 3, 4 }
.ToDictionary(k => k.ToString(), v => v);
}
}


public class InternalDictionaryTest
{
internal IDictionary<string, int> items;

public InternalDictionaryTest()
{
items = new int[5] { 0, 1, 2, 3, 4 }
.ToDictionary(k => k.ToString(), v => v);
}
}


public class PrivateDictionaryTest
{
private IDictionary<string, int> items;

public PrivateDictionaryTest()
{
items = new int[5] { 0, 1, 2, 3, 4 }
.ToDictionary(k => k.ToString(), v => v);
}
}

public class InheritedDictionaryTest : IDictionary<string, int>
{
private readonly IDictionary<string, int> items;

public InheritedDictionaryTest()
{
items = new int[5] { 0, 1, 2, 3, 4 }
.ToDictionary(k => k.ToString(), v => v);
}

public int this[string key]
{
get { return items[key]; }
set { items[key] = value; }
}

public ICollection<string> Keys => items.Keys;

public ICollection<int> Values => items.Values;

public int Count => items.Count;

public bool IsReadOnly => false;

public void Add(string key, int value) => items.Add(key, value);

public void Add(KeyValuePair<string, int> item) => items.Add(item);

public void Clear() => items.Clear();

public bool Contains(KeyValuePair<string, int> item) => items.Contains(item);

public bool ContainsKey(string key) => items.ContainsKey(key);

public void CopyTo(KeyValuePair<string, int>[] array, int arrayIndex)
{
items.CopyTo(array, arrayIndex);
}

public IEnumerator<KeyValuePair<string, int>> GetEnumerator() => items.GetEnumerator();

public bool Remove(string key) => items.Remove(key);

public bool Remove(KeyValuePair<string, int> item) => items.Remove(item);

public bool TryGetValue(string key, out int value) => items.TryGetValue(key, out value);

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}
146 changes: 146 additions & 0 deletions src/tests/test_dictionary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-

"""Test support for managed dictionaries."""

import Python.Test as Test
import System
import pytest


def test_public_dict():
"""Test public dict."""
ob = Test.PublicDictionaryTest()
items = ob.items

assert len(items) == 5

assert items['0'] == 0
assert items['4'] == 4

items['0'] = 8
assert items['0'] == 8

items['4'] = 9
assert items['4'] == 9

items['-4'] = 0
assert items['-4'] == 0

items['-1'] = 4
assert items['-1'] == 4

def test_protected_dict():
"""Test protected dict."""
ob = Test.ProtectedDictionaryTest()
items = ob.items

assert len(items) == 5

assert items['0'] == 0
assert items['4'] == 4

items['0'] = 8
assert items['0'] == 8

items['4'] = 9
assert items['4'] == 9

items['-4'] = 0
assert items['-4'] == 0

items['-1'] = 4
assert items['-1'] == 4

def test_internal_dict():
"""Test internal dict."""

with pytest.raises(AttributeError):
ob = Test.InternalDictionaryTest()
_ = ob.items

def test_private_dict():
"""Test private dict."""

with pytest.raises(AttributeError):
ob = Test.PrivateDictionaryTest()
_ = ob.items

def test_dict_contains():
"""Test dict support for __contains__."""

ob = Test.PublicDictionaryTest()
items = ob.items

assert '0' in items
assert '1' in items
assert '2' in items
assert '3' in items
assert '4' in items

assert not ('5' in items)
assert not ('-1' in items)

def test_dict_abuse():
"""Test dict abuse."""
_class = Test.PublicDictionaryTest
ob = Test.PublicDictionaryTest()

with pytest.raises(AttributeError):
del _class.__getitem__

with pytest.raises(AttributeError):
del ob.__getitem__

with pytest.raises(AttributeError):
del _class.__setitem__

with pytest.raises(AttributeError):
del ob.__setitem__

with pytest.raises(TypeError):
Test.PublicArrayTest.__getitem__(0, 0)

with pytest.raises(TypeError):
Test.PublicArrayTest.__setitem__(0, 0, 0)

with pytest.raises(TypeError):
desc = Test.PublicArrayTest.__dict__['__getitem__']
desc(0, 0)

with pytest.raises(TypeError):
desc = Test.PublicArrayTest.__dict__['__setitem__']
desc(0, 0, 0)

def test_InheritedDictionary():
"""Test class that inherited from IDictionary."""
items = Test.InheritedDictionaryTest()

assert len(items) == 5

assert items['0'] == 0
assert items['4'] == 4

items['0'] = 8
assert items['0'] == 8

items['4'] = 9
assert items['4'] == 9

items['-4'] = 0
assert items['-4'] == 0

items['-1'] = 4
assert items['-1'] == 4

def test_InheritedDictionary_contains():
"""Test dict support for __contains__ in class that inherited from IDictionary"""
items = Test.InheritedDictionaryTest()

assert '0' in items
assert '1' in items
assert '2' in items
assert '3' in items
assert '4' in items

assert not ('5' in items)
assert not ('-1' in items)

0 comments on commit de07b9a

Please sign in to comment.