Commit 2f09b7ca authored by Amanda Lund's avatar Amanda Lund

Fixed some univariate/multivariate from_xml_element issues and updated unit tests

parent 74c4b493
......@@ -980,7 +980,7 @@ class Settings(object):
value = get_text(elem, key)
if value is not None:
if key in ('summary', 'tallies'):
value = value == 'true'
value = value in ('true', '1')
self.output[key] = value
def _statepoint_from_xml_element(self, root):
......@@ -997,9 +997,9 @@ class Settings(object):
value = get_text(elem, key)
if value is not None:
if key in ('separate', 'write'):
value = value == 'true'
value = value in ('true', '1')
elif key == 'overwrite_latest':
value = value == 'true'
value = value in ('true', '1')
key = 'overwrite'
else:
value = [int(x) for x in value.split()]
......@@ -1008,7 +1008,7 @@ class Settings(object):
def _confidence_intervals_from_xml_element(self, root):
text = get_text(root, 'confidence_intervals')
if text is not None:
self.confidence_intervals = text == 'true'
self.confidence_intervals = text in ('true', '1')
def _electron_treatment_from_xml_element(self, root):
text = get_text(root, 'electron_treatment')
......@@ -1028,12 +1028,12 @@ class Settings(object):
def _photon_transport_from_xml_element(self, root):
text = get_text(root, 'photon_transport')
if text is not None:
self.photon_transport = text == 'true'
self.photon_transport = text in ('true', '1')
def _ptables_from_xml_element(self, root):
text = get_text(root, 'ptables')
if text is not None:
self.ptables = text == 'true'
self.ptables = text in ('true', '1')
def _seed_from_xml_element(self, root):
text = get_text(root, 'seed')
......@@ -1043,7 +1043,7 @@ class Settings(object):
def _survival_biasing_from_xml_element(self, root):
text = get_text(root, 'survival_biasing')
if text is not None:
self.survival_biasing = text == 'true'
self.survival_biasing = text in ('true', '1')
def _cutoff_from_xml_element(self, root):
elem = root.find('cutoff')
......@@ -1066,7 +1066,7 @@ class Settings(object):
def _trigger_from_xml_element(self, root):
elem = root.find('trigger')
if elem is not None:
self.trigger_active = get_text(elem, 'active') == 'true'
self.trigger_active = get_text(elem, 'active') in ('true', '1')
text = get_text(elem, 'max_batches')
if text is not None:
self.trigger_max_batches = int(text)
......@@ -1077,7 +1077,7 @@ class Settings(object):
def _no_reduce_from_xml_element(self, root):
text = get_text(root, 'no_reduce')
if text is not None:
self.no_reduce = text == 'true'
self.no_reduce = text in ('true', '1')
def _verbosity_from_xml_element(self, root):
text = get_text(root, 'verbosity')
......@@ -1088,7 +1088,7 @@ class Settings(object):
elem = root.find('tabular_legendre')
if elem is not None:
text = get_text(elem, 'enable')
self.tabular_legendre['enable'] = text == 'true'
self.tabular_legendre['enable'] = text in ('true', '1')
text = get_text(elem, 'num_points')
if text is not None:
self.tabular_legendre['num_points'] = int(text)
......@@ -1108,7 +1108,7 @@ class Settings(object):
self.temperature['range'] = [float(x) for x in text.split()]
text = get_text(root, 'temperature_multipole')
if text is not None:
self.temperature['multipole'] = text == 'true'
self.temperature['multipole'] = text in ('true', '1')
def _trace_from_xml_element(self, root):
text = get_text(root, 'trace')
......@@ -1136,7 +1136,7 @@ class Settings(object):
value = get_text(elem, key)
if value is not None:
if key == 'enable':
value = value == 'true'
value = value in ('true', '1')
elif key in ('energy_min', 'energy_max'):
value = float(value)
elif key == 'nuclides':
......@@ -1146,7 +1146,7 @@ class Settings(object):
def _create_fission_neutrons_from_xml_element(self, root):
text = get_text(root, 'create_fission_neutrons')
if text is not None:
self.create_fission_neutrons = text == 'true'
self.create_fission_neutrons = text in ('true', '1')
def _log_grid_bins_from_xml_element(self, root):
text = get_text(root, 'log_grid_bins')
......@@ -1156,7 +1156,7 @@ class Settings(object):
def _dagmc_from_xml_element(self, root):
text = get_text(root, 'dagmc')
if text is not None:
self.dagmc = text == 'true'
self.dagmc = text in ('true', '1')
def export_to_xml(self, path='settings.xml'):
"""Export simulation settings to an XML file.
......
......@@ -3,11 +3,8 @@ import sys
from xml.etree import ElementTree as ET
from openmc._xml import get_text
from openmc.stats.univariate import (Univariate, Discrete, Uniform, Maxwell,
Watt, Normal, Muir, Tabular)
from openmc.stats.multivariate import (UnitSphere, Spatial, PolarAzimuthal,
Isotropic, Monodirectional, Box, Point,
CartesianIndependent)
from openmc.stats.univariate import Univariate
from openmc.stats.multivariate import UnitSphere, Spatial
import openmc.checkvalue as cv
......@@ -173,40 +170,14 @@ class Source(object):
space = elem.find('space')
if space is not None:
space_type = get_text(space, 'type')
if space_type == 'cartesian':
source.space = CartesianIndependent.from_xml_element(space)
elif space_type == 'box' or space_type == 'fission':
source.space = Box.from_xml_element(space)
elif space_type == 'point':
source.space = Point.from_xml_element(space)
source.space = Spatial.from_xml_element(space)
angle = elem.find('angle')
if angle is not None:
angle_type = get_text(angle, 'type')
if angle_type == 'mu-phi':
source.angle = PolarAzimuthal.from_xml_element(angle)
elif angle_type == 'isotropic':
source.angle = Isotropic.from_xml_element(angle)
elif angle_type == 'monodirectional':
source.angle = Monodirectional.from_xml_element(angle)
source.angle = UnitSphere.from_xml_element(angle)
energy = elem.find('energy')
if energy is not None:
energy_type = get_text(energy, 'type')
if energy_type == 'discrete':
source.energy = Discrete.from_xml_element(energy)
elif energy_type == 'uniform':
source.energy = Uniform.from_xml_element(energy)
elif energy_type == 'maxwell':
source.energy = Maxwell.from_xml_element(energy)
elif energy_type == 'watt':
source.energy = Watt.from_xml_element(energy)
elif energy_type == 'normal':
source.energy = Normal.from_xml_element(energy)
elif energy_type == 'muir':
source.energy = Muir.from_xml_element(energy)
elif energy_type == 'tabular':
source.energy = Tabular.from_xml_element(energy)
source.energy = Univariate.from_xml_element(energy)
return source
......@@ -51,7 +51,13 @@ class UnitSphere(metaclass=ABCMeta):
@classmethod
@abstractmethod
def from_xml_element(cls, elem):
pass
distribution = get_text(elem, 'type')
if distribution == 'mu-phi':
return PolarAzimuthal.from_xml_element(elem)
elif distribution == 'isotropic':
return Isotropic.from_xml_element(elem)
elif distribution == 'monodirectional':
return Monodirectional.from_xml_element(elem)
class PolarAzimuthal(UnitSphere):
......@@ -146,8 +152,8 @@ class PolarAzimuthal(UnitSphere):
params = get_text(elem, 'parameters')
if params is not None:
mu_phi.reference_uvw = [float(x) for x in params.split()]
mu_phi.mu = openmc.stats.Univariate.from_xml_element(elem.find('mu'))
mu_phi.phi = openmc.stats.Univariate.from_xml_element(elem.find('phi'))
mu_phi.mu = Univariate.from_xml_element(elem.find('mu'))
mu_phi.phi = Univariate.from_xml_element(elem.find('phi'))
return mu_phi
......@@ -263,7 +269,13 @@ class Spatial(metaclass=ABCMeta):
@classmethod
@abstractmethod
def from_xml_element(cls, elem):
pass
distribution = get_text(elem, 'type')
if distribution == 'cartesian':
return CartesianIndependent.from_xml_element(elem)
elif distribution == 'box' or distribution == 'fission':
return Box.from_xml_element(elem)
elif distribution == 'point':
return Point.from_xml_element(elem)
class CartesianIndependent(Spatial):
......@@ -357,9 +369,9 @@ class CartesianIndependent(Spatial):
Spatial distribution generated from XML element
"""
x = openmc.stats.Univariate.from_xml_element(elem.find('x'))
y = openmc.stats.Univariate.from_xml_element(elem.find('y'))
z = openmc.stats.Univariate.from_xml_element(elem.find('z'))
x = Univariate.from_xml_element(elem.find('x'))
y = Univariate.from_xml_element(elem.find('y'))
z = Univariate.from_xml_element(elem.find('z'))
return cls(x, y, z)
......
......@@ -36,7 +36,25 @@ class Univariate(EqualityMixin, metaclass=ABCMeta):
@classmethod
@abstractmethod
def from_xml_element(cls, elem):
pass
distribution = get_text(elem, 'type')
if distribution == 'discrete':
return Discrete.from_xml_element(elem)
elif distribution == 'uniform':
return Uniform.from_xml_element(elem)
elif distribution == 'maxwell':
return Maxwell.from_xml_element(elem)
elif distribution == 'watt':
return Watt.from_xml_element(elem)
elif distribution == 'normal':
return Normal.from_xml_element(elem)
elif distribution == 'muir':
return Muir.from_xml_element(elem)
elif distribution == 'tabular':
return Tabular.from_xml_element(elem)
elif distribution == 'legendre':
return Legendre.from_xml_element(elem)
elif distribution == 'mixture':
return Mixture.from_xml_element(elem)
class Discrete(Univariate):
......@@ -223,9 +241,7 @@ class Uniform(Univariate):
"""
params = get_text(elem, 'parameters').split()
a = float(params[0])
b = float(params[1])
return cls(a, b)
return cls(*map(float, params))
class Maxwell(Univariate):
......@@ -388,9 +404,7 @@ class Watt(Univariate):
"""
params = get_text(elem, 'parameters').split()
a = float(params[0])
b = float(params[1])
return watt(a, b)
return cls(*map(float, params))
class Normal(Univariate):
......@@ -478,9 +492,7 @@ class Normal(Univariate):
"""
params = get_text(elem, 'parameters').split()
mean_value = float(params[0])
std_dev = float(params[1])
return cls(mean_value, std_dev)
return cls(*map(float, params))
class Muir(Univariate):
......@@ -587,10 +599,7 @@ class Muir(Univariate):
"""
params = get_text(elem, 'parameters').split()
e0 = float(params[0])
m_rat = float(params[1])
kt = float(params[2])
return muir(e0, m_rat, kt)
return cls(*map(float, params))
class Tabular(Univariate):
......@@ -706,7 +715,7 @@ class Tabular(Univariate):
interpolation = get_text(elem, 'interpolation')
params = [float(x) for x in get_text(elem, 'parameters').split()]
x = params[:len(params)//2]
p = paramx[len(params)//2:]
p = params[len(params)//2:]
return cls(x, p, interpolation)
......
......@@ -11,7 +11,6 @@ def test_source():
assert src.space == space
assert src.angle == angle
assert src.energy == energy
assert src.strength == 1.0
elem = src.to_xml_element()
assert 'strength' in elem.attrib
......@@ -19,6 +18,13 @@ def test_source():
assert elem.find('angle') is not None
assert elem.find('energy') is not None
src = openmc.Source.from_xml_element(elem)
assert isinstance(src.angle, openmc.stats.Isotropic)
assert src.space.xyz == [0.0, 0.0, 0.0]
assert src.energy.x == [1.0e6]
assert src.energy.p == [1.0]
assert src.strength == 1.0
def test_source_file():
filename = 'source.h5'
......
......@@ -10,10 +10,15 @@ def test_discrete():
x = [0.0, 1.0, 10.0]
p = [0.3, 0.2, 0.5]
d = openmc.stats.Discrete(x, p)
elem = d.to_xml_element('distribution')
d = openmc.stats.Discrete.from_xml_element(elem)
assert d.x == x
assert d.p == p
assert len(d) == len(x)
d.to_xml_element('distribution')
d = openmc.stats.Univariate.from_xml_element(elem)
assert isinstance(d, openmc.stats.Discrete)
# Single point
d2 = openmc.stats.Discrete(1e6, 1.0)
......@@ -25,6 +30,9 @@ def test_discrete():
def test_uniform():
a, b = 10.0, 20.0
d = openmc.stats.Uniform(a, b)
elem = d.to_xml_element('distribution')
d = openmc.stats.Uniform.from_xml_element(elem)
assert d.a == a
assert d.b == b
assert len(d) == 2
......@@ -34,35 +42,39 @@ def test_uniform():
assert t.p == [1/(b-a), 1/(b-a)]
assert t.interpolation == 'histogram'
d.to_xml_element('distribution')
def test_maxwell():
theta = 1.2895e6
d = openmc.stats.Maxwell(theta)
elem = d.to_xml_element('distribution')
d = openmc.stats.Maxwell.from_xml_element(elem)
assert d.theta == theta
assert len(d) == 1
d.to_xml_element('distribution')
def test_watt():
a, b = 0.965e6, 2.29e-6
d = openmc.stats.Watt(a, b)
elem = d.to_xml_element('distribution')
d = openmc.stats.Watt.from_xml_element(elem)
assert d.a == a
assert d.b == b
assert len(d) == 2
d.to_xml_element('distribution')
def test_tabular():
x = [0.0, 5.0, 7.0]
p = [0.1, 0.2, 0.05]
d = openmc.stats.Tabular(x, p, 'linear-linear')
elem = d.to_xml_element('distribution')
d = openmc.stats.Tabular.from_xml_element(elem)
assert d.x == x
assert d.p == p
assert d.interpolation == 'linear-linear'
assert len(d) == len(x)
d.to_xml_element('distribution')
def test_legendre():
......@@ -115,6 +127,15 @@ def test_polar_azimuthal():
assert elem.find('mu') is not None
assert elem.find('phi') is not None
d = openmc.stats.PolarAzimuthal.from_xml_element(elem)
assert d.mu.x == [1.]
assert d.mu.p == [1.]
assert d.phi.x == [0.]
assert d.phi.p == [1.]
d = openmc.stats.UnitSphere.from_xml_element(elem)
assert isinstance(d, openmc.stats.PolarAzimuthal)
def test_isotropic():
d = openmc.stats.Isotropic()
......@@ -122,24 +143,25 @@ def test_isotropic():
assert elem.tag == 'angle'
assert elem.attrib['type'] == 'isotropic'
d = openmc.stats.Isotropic.from_xml_element(elem)
assert isinstance(d, openmc.stats.Isotropic)
def test_monodirectional():
d = openmc.stats.Monodirectional((1., 0., 0.))
assert d.reference_uvw == pytest.approx((1., 0., 0.))
elem = d.to_xml_element()
assert elem.tag == 'angle'
assert elem.attrib['type'] == 'monodirectional'
d = openmc.stats.Monodirectional.from_xml_element(elem)
assert d.reference_uvw == pytest.approx((1., 0., 0.))
def test_cartesian():
x = openmc.stats.Uniform(-10., 10.)
y = openmc.stats.Uniform(-10., 10.)
z = openmc.stats.Uniform(0., 20.)
d = openmc.stats.CartesianIndependent(x, y, z)
assert d.x == x
assert d.y == y
assert d.z == z
elem = d.to_xml_element()
assert elem.tag == 'space'
......@@ -147,55 +169,75 @@ def test_cartesian():
assert elem.find('x') is not None
assert elem.find('y') is not None
d = openmc.stats.CartesianIndependent.from_xml_element(elem)
assert d.x == x
assert d.y == y
assert d.z == z
d = openmc.stats.Spatial.from_xml_element(elem)
assert isinstance(d, openmc.stats.CartesianIndependent)
def test_box():
lower_left = (-10., -10., -10.)
upper_right = (10., 10., 10.)
d = openmc.stats.Box(lower_left, upper_right)
assert d.lower_left == pytest.approx(lower_left)
assert d.upper_right == pytest.approx(upper_right)
assert not d.only_fissionable
elem = d.to_xml_element()
assert elem.tag == 'space'
assert elem.attrib['type'] == 'box'
assert elem.find('parameters') is not None
d = openmc.stats.Box.from_xml_element(elem)
assert d.lower_left == pytest.approx(lower_left)
assert d.upper_right == pytest.approx(upper_right)
assert not d.only_fissionable
# only fissionable parameter
d2 = openmc.stats.Box(lower_left, upper_right, True)
assert d2.only_fissionable
elem = d2.to_xml_element()
assert elem.attrib['type'] == 'fission'
d = openmc.stats.Spatial.from_xml_element(elem)
assert isinstance(d, openmc.stats.Box)
def test_point():
p = (-4., 2., 10.)
d = openmc.stats.Point(p)
assert d.xyz == pytest.approx(p)
elem = d.to_xml_element()
assert elem.tag == 'space'
assert elem.attrib['type'] == 'point'
assert elem.find('parameters') is not None
d = openmc.stats.Point.from_xml_element(elem)
assert d.xyz == pytest.approx(p)
def test_normal():
mean = 10.0
std_dev = 2.0
d = openmc.stats.Normal(mean,std_dev)
elem = d.to_xml_element('distribution')
assert elem.attrib['type'] == 'normal'
d = openmc.stats.Normal.from_xml_element(elem)
assert d.mean_value == pytest.approx(mean)
assert d.std_dev == pytest.approx(std_dev)
assert len(d) == 2
elem = d.to_xml_element('distribution')
assert elem.attrib['type'] == 'normal'
def test_muir():
mean = 10.0
mass = 5.0
temp = 20000.
d = openmc.stats.Muir(mean,mass,temp)
elem = d.to_xml_element('energy')
assert elem.attrib['type'] == 'muir'
d = openmc.stats.Muir.from_xml_element(elem)
assert d.e0 == pytest.approx(mean)
assert d.m_rat == pytest.approx(mass)
assert d.kt == pytest.approx(temp)
assert len(d) == 3
elem = d.to_xml_element('energy')
assert elem.attrib['type'] == 'muir'
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment